{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "import scipy.stats as stats\n",
    "from pathlib import Path\n",
    "import glob\n",
    "import pickle\n",
    "\n",
    "import random\n",
    "import os\n",
    "\n",
    "from sklearn.model_selection import StratifiedKFold\n",
    "from sklearn.preprocessing import StandardScaler, LabelEncoder\n",
    "from tqdm import tqdm\n",
    "import tensorflow as tf\n",
    "import tensorflow.keras.layers as L\n",
    "import tensorflow.keras.models as M\n",
    "import tensorflow.keras.backend as K\n",
    "import tensorflow_addons as tfa\n",
    "from tensorflow_addons.layers import WeightNormalization\n",
    "from tensorflow.keras.callbacks import ReduceLROnPlateau, ModelCheckpoint, EarlyStopping\n",
    "pd.options.mode.chained_assignment = None\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[]\n"
     ]
    }
   ],
   "source": [
    "def get_available_gpus():\n",
    "    \"\"\"Return the names of the GPU devices TensorFlow can see.\n",
    "\n",
    "    Uses the public ``tf.config.list_physical_devices`` API instead of the\n",
    "    private ``tensorflow.python.client.device_lib`` module, which is not part\n",
    "    of TensorFlow's supported API.\n",
    "    \"\"\"\n",
    "    return [device.name for device in tf.config.list_physical_devices('GPU')]\n",
    "\n",
    "print(get_available_gpus())\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>site_path_timestamp</th>\n",
       "      <th>site</th>\n",
       "      <th>path</th>\n",
       "      <th>ts_waypoint</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>5a0546857ecc773753327266_046cfa46be49fc1083481...</td>\n",
       "      <td>5a0546857ecc773753327266</td>\n",
       "      <td>046cfa46be49fc10834815c6</td>\n",
       "      <td>9</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>5a0546857ecc773753327266_046cfa46be49fc1083481...</td>\n",
       "      <td>5a0546857ecc773753327266</td>\n",
       "      <td>046cfa46be49fc10834815c6</td>\n",
       "      <td>9017</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>5a0546857ecc773753327266_046cfa46be49fc1083481...</td>\n",
       "      <td>5a0546857ecc773753327266</td>\n",
       "      <td>046cfa46be49fc10834815c6</td>\n",
       "      <td>15326</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>5a0546857ecc773753327266_046cfa46be49fc1083481...</td>\n",
       "      <td>5a0546857ecc773753327266</td>\n",
       "      <td>046cfa46be49fc10834815c6</td>\n",
       "      <td>18763</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>5a0546857ecc773753327266_046cfa46be49fc1083481...</td>\n",
       "      <td>5a0546857ecc773753327266</td>\n",
       "      <td>046cfa46be49fc10834815c6</td>\n",
       "      <td>22328</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                                 site_path_timestamp  \\\n",
       "0  5a0546857ecc773753327266_046cfa46be49fc1083481...   \n",
       "1  5a0546857ecc773753327266_046cfa46be49fc1083481...   \n",
       "2  5a0546857ecc773753327266_046cfa46be49fc1083481...   \n",
       "3  5a0546857ecc773753327266_046cfa46be49fc1083481...   \n",
       "4  5a0546857ecc773753327266_046cfa46be49fc1083481...   \n",
       "\n",
       "                       site                      path  ts_waypoint  \n",
       "0  5a0546857ecc773753327266  046cfa46be49fc10834815c6            9  \n",
       "1  5a0546857ecc773753327266  046cfa46be49fc10834815c6         9017  \n",
       "2  5a0546857ecc773753327266  046cfa46be49fc10834815c6        15326  \n",
       "3  5a0546857ecc773753327266  046cfa46be49fc10834815c6        18763  \n",
       "4  5a0546857ecc773753327266  046cfa46be49fc10834815c6        22328  "
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Test keys look like \"<site>_<path>_<timestamp>\"; split them into separate columns\n",
    "# and drop the placeholder prediction columns.\n",
    "sample_submission = pd.read_csv(\"../input/indoor-location-navigation/sample_submission.csv\")\n",
    "sample_submission = sample_submission.drop(columns=['floor', 'x', 'y'])\n",
    "key_parts = sample_submission['site_path_timestamp'].str.split('_', expand=True)\n",
    "sample_submission['site'] = key_parts[0]\n",
    "sample_submission['path'] = key_parts[1]\n",
    "sample_submission['ts_waypoint'] = key_parts[2].map(int)\n",
    "\n",
    "# Lookup from test path id to its site id.\n",
    "path2site = dict(zip(sample_submission.path, sample_submission.site))\n",
    "sample_submission.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# glob() returns files in whatever order the filesystem lists them, which can\n",
    "# differ between machines; sort so every run loads and concatenates in the same order.\n",
    "test_wifi_files = sorted(glob.glob('../input/wifi_lbl_encode/test/*.txt'))\n",
    "test_sensor_files = sorted(glob.glob('../input/data_abstract/*_test_sensor.csv'))\n",
    "\n",
    "# train_files = glob.glob('../input/indoor-navigation-and-location-wifi-features-alldata/*train.csv') #if A \n",
    "train_files = sorted(glob.glob('../input/data_abstract/*_train_waypoint_all.csv'))  # if B\n",
    "\n",
    "# Training wifi files are laid out as wifi_lbl_encode/train/<site>/<floor>/<path>.txt.\n",
    "train_wifi_files = sorted(glob.glob('../input/wifi_lbl_encode/train/*/*/*.txt'))\n",
    "train_sensor_files = sorted(glob.glob('../input/data_abstract/*_train_sensor.csv'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['../input/data_abstract/5a0546857ecc773753327266_train_waypoint_all.csv',\n",
       " '../input/data_abstract/5c3c44b80379370013e0fd2b_train_waypoint_all.csv']"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Peek at the first two training waypoint files found by the glob.\n",
    "train_files[:2]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "len train site list: 24\n"
     ]
    }
   ],
   "source": [
    "# Alternative site lists, kept for reference:\n",
    "# train_site_list = [xx.split('/')[-1].replace('_train.csv','') for xx in train_files] #if A \n",
    "# train_site_list = [xx.split('/')[-1].replace('_train_waypoint_all.csv','') for xx in train_files] #if B 204\n",
    "\n",
    "# Train only on sites that also occur in the test set (24 sites here).\n",
    "train_site_list = list(sample_submission.site.unique())\n",
    "wanted_sites = set(train_site_list)  # constant-time membership test per file\n",
    "# Wifi files are laid out as .../<site>/<floor>/<path>.txt, so index -3 is the site id.\n",
    "train_wifi_files = [fname for fname in train_wifi_files if fname.split('/')[-3] in wanted_sites]\n",
    "print('len train site list:',len(train_site_list))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "10877"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Number of training wifi files left after keeping only the test-set sites.\n",
    "len(train_wifi_files)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 11503/11503 [01:01<00:00, 187.96it/s]\n"
     ]
    }
   ],
   "source": [
    "# Collect every SSID / BSSID label seen in the train or test wifi files.\n",
    "# set.update() grows the sets in place; the previous `a = a | set(b)` copied the\n",
    "# whole accumulated set once per file. Only the two needed columns are read.\n",
    "ssidlist = set()\n",
    "bssidlist = set()\n",
    "for filename in tqdm(train_wifi_files+test_wifi_files):\n",
    "    tmp = pd.read_csv(filename, usecols=['ssid', 'bssid'])\n",
    "    ssidlist.update(tmp.ssid)\n",
    "    bssidlist.update(tmp.bssid)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(20044, 65952)"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Vocabulary sizes; both are already sets, so no extra set() copy is needed.\n",
    "len(ssidlist), len(bssidlist)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fixed number of wifi readings kept per scan timestamp: longer groups are\n",
    "# truncated and shorter ones padded to this length when sequences are built.\n",
    "seqlen = 100"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Map every SSID / BSSID label to an integer id, plus one extra 'empty' id used for padding.\n",
    "# The labels are sorted first: set iteration order depends on insertion (file) order and,\n",
    "# for string labels, on the per-process hash seed, so ids taken straight from the set can\n",
    "# change between runs. key=str keeps the sort well-defined even if label types are mixed.\n",
    "ssiddict = {label: idx for idx, label in enumerate(sorted(ssidlist, key=str) + ['empty'])}\n",
    "bssiddict = {label: idx for idx, label in enumerate(sorted(bssidlist, key=str) + ['empty'])}\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 10877/10877 [00:43<00:00, 252.76it/s]\n"
     ]
    }
   ],
   "source": [
    "# Load every training wifi file and tag its rows with the path/floor/site ids\n",
    "# taken from the .../<site>/<floor>/<path>.txt file location.\n",
    "wifi_frames = []\n",
    "for filename in tqdm(train_wifi_files):\n",
    "    *_, site_id, floor_label, file_name = filename.split('/')\n",
    "    wifi_frames.append(pd.read_csv(filename).assign(path=file_name.replace('.txt',''),\n",
    "                                                    floor=floor_label,\n",
    "                                                    site=site_id))\n",
    "train_wifi_pd_csv = pd.concat(wifi_frames).reset_index(drop=True)\n",
    "del wifi_frames"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|          | 0/24 [00:00<?, ?it/s]/home/ec2-user/anaconda3/lib/python3.7/site-packages/numpy/lib/arraysetops.py:569: FutureWarning: elementwise comparison failed; returning scalar instead, but in the future will perform elementwise comparison\n",
      "  mask |= (ar1 == a)\n",
      "100%|██████████| 24/24 [01:02<00:00,  2.59s/it]\n"
     ]
    }
   ],
   "source": [
    "# Stack the per-site training sensor tables (column 0 of each CSV is read as the index).\n",
    "train_sensor_pd_csv = pd.concat(\n",
    "    [pd.read_csv(fname, index_col=0) for fname in tqdm(train_sensor_files)]\n",
    ").reset_index(drop=True)\n",
    "# 'magne' is the SQUARED magnetometer magnitude (x^2 + y^2 + z^2); no square root is taken.\n",
    "train_sensor_pd_csv['magne'] = sum(train_sensor_pd_csv[col]**2 for col in ('x_magne', 'y_magne', 'z_magne'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>ts_sensor</th>\n",
       "      <th>x_acce</th>\n",
       "      <th>y_acce</th>\n",
       "      <th>z_acce</th>\n",
       "      <th>x_magne</th>\n",
       "      <th>y_magne</th>\n",
       "      <th>z_magne</th>\n",
       "      <th>x_gyros</th>\n",
       "      <th>y_gyros</th>\n",
       "      <th>z_gyros</th>\n",
       "      <th>x_rotate</th>\n",
       "      <th>y_rotate</th>\n",
       "      <th>z_rotate</th>\n",
       "      <th>path</th>\n",
       "      <th>site</th>\n",
       "      <th>floor</th>\n",
       "      <th>floor_ori</th>\n",
       "      <th>magne</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>1.578463e+12</td>\n",
       "      <td>0.023697</td>\n",
       "      <td>4.450943</td>\n",
       "      <td>9.055649</td>\n",
       "      <td>-0.037537</td>\n",
       "      <td>0.075256</td>\n",
       "      <td>0.030579</td>\n",
       "      <td>-13.391113</td>\n",
       "      <td>9.959412</td>\n",
       "      <td>-30.305481</td>\n",
       "      <td>0.227164</td>\n",
       "      <td>-0.058094</td>\n",
       "      <td>-0.268773</td>\n",
       "      <td>5e15730aa280850006f3d005</td>\n",
       "      <td>5a0546857ecc773753327266</td>\n",
       "      <td>-1</td>\n",
       "      <td>B1</td>\n",
       "      <td>0.008008</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>1.578463e+12</td>\n",
       "      <td>0.050629</td>\n",
       "      <td>4.552109</td>\n",
       "      <td>9.074799</td>\n",
       "      <td>-0.043411</td>\n",
       "      <td>-0.005722</td>\n",
       "      <td>0.009796</td>\n",
       "      <td>-12.002563</td>\n",
       "      <td>9.959412</td>\n",
       "      <td>-28.955078</td>\n",
       "      <td>0.225032</td>\n",
       "      <td>-0.059640</td>\n",
       "      <td>-0.267238</td>\n",
       "      <td>5e15730aa280850006f3d005</td>\n",
       "      <td>5a0546857ecc773753327266</td>\n",
       "      <td>-1</td>\n",
       "      <td>B1</td>\n",
       "      <td>0.002013</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "      ts_sensor    x_acce    y_acce    z_acce   x_magne   y_magne   z_magne  \\\n",
       "0  1.578463e+12  0.023697  4.450943  9.055649 -0.037537  0.075256  0.030579   \n",
       "1  1.578463e+12  0.050629  4.552109  9.074799 -0.043411 -0.005722  0.009796   \n",
       "\n",
       "     x_gyros   y_gyros    z_gyros  x_rotate  y_rotate  z_rotate  \\\n",
       "0 -13.391113  9.959412 -30.305481  0.227164 -0.058094 -0.268773   \n",
       "1 -12.002563  9.959412 -28.955078  0.225032 -0.059640 -0.267238   \n",
       "\n",
       "                       path                      site  floor floor_ori  \\\n",
       "0  5e15730aa280850006f3d005  5a0546857ecc773753327266     -1        B1   \n",
       "1  5e15730aa280850006f3d005  5a0546857ecc773753327266     -1        B1   \n",
       "\n",
       "      magne  \n",
       "0  0.008008  \n",
       "1  0.002013  "
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# First two rows of the combined training sensor table (before standardisation).\n",
    "train_sensor_pd_csv.head(2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Floor label -> integer level; basement floors are negative and F1 / 1F map to 0.\n",
    "floor_map = {\n",
    "    \"B3\": -3, \"B2\": -2, \"B1\": -1,\n",
    "    \"F1\": 0, \"F2\": 1, \"F3\": 2, \"F4\": 3, \"F5\": 4, \"F6\": 5, \"F7\": 6, \"F8\": 7, \"F9\": 8,\n",
    "    \"1F\": 0, \"2F\": 1, \"3F\": 2, \"4F\": 3, \"5F\": 4, \"6F\": 5, \"7F\": 6, \"8F\": 7, \"9F\": 8,\n",
    "}\n",
    "# Keep only readings whose floor label is covered by floor_map, then encode it.\n",
    "known_floor = train_wifi_pd_csv['floor'].isin(list(floor_map))\n",
    "train_wifi_pd_csv = train_wifi_pd_csv[known_floor].reset_index(drop=True)\n",
    "train_wifi_pd_csv['floorNo'] = train_wifi_pd_csv['floor'].map(floor_map)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>timestamp</th>\n",
       "      <th>ssid</th>\n",
       "      <th>bssid</th>\n",
       "      <th>rssi</th>\n",
       "      <th>last_timestamp</th>\n",
       "      <th>path</th>\n",
       "      <th>floor</th>\n",
       "      <th>site</th>\n",
       "      <th>floorNo</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>1578462618826</td>\n",
       "      <td>63159</td>\n",
       "      <td>162932</td>\n",
       "      <td>-46</td>\n",
       "      <td>1578462603277</td>\n",
       "      <td>5e15730aa280850006f3d005</td>\n",
       "      <td>B1</td>\n",
       "      <td>5a0546857ecc773753327266</td>\n",
       "      <td>-1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>1578462618826</td>\n",
       "      <td>32835</td>\n",
       "      <td>65513</td>\n",
       "      <td>-49</td>\n",
       "      <td>1578462618272</td>\n",
       "      <td>5e15730aa280850006f3d005</td>\n",
       "      <td>B1</td>\n",
       "      <td>5a0546857ecc773753327266</td>\n",
       "      <td>-1</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "       timestamp   ssid   bssid  rssi  last_timestamp  \\\n",
       "0  1578462618826  63159  162932   -46   1578462603277   \n",
       "1  1578462618826  32835   65513   -49   1578462618272   \n",
       "\n",
       "                       path floor                      site  floorNo  \n",
       "0  5e15730aa280850006f3d005    B1  5a0546857ecc773753327266       -1  \n",
       "1  5e15730aa280850006f3d005    B1  5a0546857ecc773753327266       -1  "
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# First rows of the training wifi table, now with the integer floorNo column.\n",
    "train_wifi_pd_csv.head(2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 626/626 [00:02<00:00, 208.79it/s]\n"
     ]
    }
   ],
   "source": [
    "# Load the test wifi files; the path id is the file name without the '.txt' suffix.\n",
    "test_wifi_pd_csv = pd.concat(\n",
    "    [pd.read_csv(fname).assign(path=fname.split('/')[-1].replace('.txt',''))\n",
    "     for fname in tqdm(test_wifi_files)]\n",
    ").reset_index(drop=True)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 24/24 [00:07<00:00,  3.04it/s]\n"
     ]
    }
   ],
   "source": [
    "# Stack the per-site test sensor tables.\n",
    "# NOTE(review): unlike the training sensor files these are read WITHOUT index_col=0;\n",
    "# if the CSVs carry a saved index column it ends up as an extra 'Unnamed: 0' column -- confirm.\n",
    "test_sensor_pd_csv = pd.concat(\n",
    "    [pd.read_csv(fname) for fname in tqdm(test_sensor_files)]\n",
    ").reset_index(drop=True)\n",
    "\n",
    "# 'magne' is the SQUARED magnetometer magnitude, computed the same way as for training.\n",
    "test_sensor_pd_csv['magne'] = sum(test_sensor_pd_csv[col]**2 for col in ('x_magne', 'y_magne', 'z_magne'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Standardise the sensor channels with statistics fitted on the training data only,\n",
    "# then apply the same fitted scaling to the test data.\n",
    "standcols_sensor = ['x_acce', 'y_acce', 'z_acce',\n",
    "                    'x_magne', 'y_magne', 'z_magne',\n",
    "                    'x_gyros', 'y_gyros', 'z_gyros',\n",
    "                    'x_rotate', 'y_rotate', 'z_rotate',\n",
    "                    'magne']\n",
    "ss_sensor = StandardScaler()\n",
    "train_sensor_pd_csv.loc[:, standcols_sensor] = ss_sensor.fit_transform(train_sensor_pd_csv.loc[:, standcols_sensor])\n",
    "test_sensor_pd_csv.loc[:, standcols_sensor] = ss_sensor.transform(test_sensor_pd_csv.loc[:, standcols_sensor])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>timestamp</th>\n",
       "      <th>ssid</th>\n",
       "      <th>bssid</th>\n",
       "      <th>rssi</th>\n",
       "      <th>last_timestamp</th>\n",
       "      <th>path</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>1961</td>\n",
       "      <td>70537</td>\n",
       "      <td>28318</td>\n",
       "      <td>-34</td>\n",
       "      <td>1571828560156</td>\n",
       "      <td>14f45baa63b4d3a700126af6</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>1961</td>\n",
       "      <td>43838</td>\n",
       "      <td>93116</td>\n",
       "      <td>-35</td>\n",
       "      <td>1571828560159</td>\n",
       "      <td>14f45baa63b4d3a700126af6</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   timestamp   ssid  bssid  rssi  last_timestamp                      path\n",
       "0       1961  70537  28318   -34   1571828560156  14f45baa63b4d3a700126af6\n",
       "1       1961  43838  93116   -35   1571828560159  14f45baa63b4d3a700126af6"
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# First rows of the combined test wifi table.\n",
    "test_wifi_pd_csv.head(2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per-path floor predictions read from submission_floor.csv. When a path has several\n",
    "# rows, dict() keeps the floor from its last row; unknown paths raise KeyError.\n",
    "submission = pd.read_csv('submission_floor.csv')\n",
    "submission['path'] = [key.split('_')[1] for key in submission['site_path_timestamp']]\n",
    "test_path_floor_dict = dict(zip(submission.path, submission.floor))\n",
    "test_wifi_pd_csv['floorNo'] = test_wifi_pd_csv['path'].map(lambda p: test_path_floor_dict[p])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Standardise rssi and floorNo with statistics fitted on the training wifi table;\n",
    "# the test table is transformed with the same fitted scaler.\n",
    "standcols = ['rssi', 'floorNo']\n",
    "ss = StandardScaler()\n",
    "train_wifi_pd_csv.loc[:, standcols] = ss.fit_transform(train_wifi_pd_csv.loc[:, standcols])\n",
    "test_wifi_pd_csv.loc[:, standcols] = ss.transform(test_wifi_pd_csv.loc[:, standcols])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>timestamp</th>\n",
       "      <th>ssid</th>\n",
       "      <th>bssid</th>\n",
       "      <th>rssi</th>\n",
       "      <th>last_timestamp</th>\n",
       "      <th>path</th>\n",
       "      <th>floor</th>\n",
       "      <th>site</th>\n",
       "      <th>floorNo</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>1578462618826</td>\n",
       "      <td>63159</td>\n",
       "      <td>162932</td>\n",
       "      <td>3.105926</td>\n",
       "      <td>1578462603277</td>\n",
       "      <td>5e15730aa280850006f3d005</td>\n",
       "      <td>B1</td>\n",
       "      <td>5a0546857ecc773753327266</td>\n",
       "      <td>-1.340327</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>1578462618826</td>\n",
       "      <td>32835</td>\n",
       "      <td>65513</td>\n",
       "      <td>2.810727</td>\n",
       "      <td>1578462618272</td>\n",
       "      <td>5e15730aa280850006f3d005</td>\n",
       "      <td>B1</td>\n",
       "      <td>5a0546857ecc773753327266</td>\n",
       "      <td>-1.340327</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "       timestamp   ssid   bssid      rssi  last_timestamp  \\\n",
       "0  1578462618826  63159  162932  3.105926   1578462603277   \n",
       "1  1578462618826  32835   65513  2.810727   1578462618272   \n",
       "\n",
       "                       path floor                      site   floorNo  \n",
       "0  5e15730aa280850006f3d005    B1  5a0546857ecc773753327266 -1.340327  \n",
       "1  5e15730aa280850006f3d005    B1  5a0546857ecc773753327266 -1.340327  "
      ]
     },
     "execution_count": 23,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# First rows of the training wifi table after rssi/floorNo standardisation.\n",
    "train_wifi_pd_csv.head(2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 10877/10877 [02:51<00:00, 63.29it/s] \n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>timestamp</th>\n",
       "      <th>ssid</th>\n",
       "      <th>bssid</th>\n",
       "      <th>rssi</th>\n",
       "      <th>path</th>\n",
       "      <th>floorNo</th>\n",
       "      <th>floor</th>\n",
       "      <th>site</th>\n",
       "      <th>wifi_len</th>\n",
       "      <th>wifi_mean</th>\n",
       "      <th>wifi_median</th>\n",
       "      <th>wifi_std</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>1560500997770</td>\n",
       "      <td>[7702, 19396, 18304, 19396, 7702, 7702, 19396,...</td>\n",
       "      <td>[61027, 55262, 10121, 57287, 45809, 53865, 261...</td>\n",
       "      <td>[3.204325463643926, 3.1059258532748903, 2.9091...</td>\n",
       "      <td>5d073b814a19c000086c558b</td>\n",
       "      <td>0.299386</td>\n",
       "      <td>F3</td>\n",
       "      <td>5c3c44b80379370013e0fd2b</td>\n",
       "      <td>0.206</td>\n",
       "      <td>0.353603</td>\n",
       "      <td>0.350737</td>\n",
       "      <td>1.088208</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>1560500999681</td>\n",
       "      <td>[18304, 7702, 7702, 19396, 19396, 7702, 7702, ...</td>\n",
       "      <td>[10121, 31140, 61027, 55262, 57287, 53865, 458...</td>\n",
       "      <td>[2.712327411798748, 2.712327411798748, 2.61392...</td>\n",
       "      <td>5d073b814a19c000086c558b</td>\n",
       "      <td>0.299386</td>\n",
       "      <td>F3</td>\n",
       "      <td>5c3c44b80379370013e0fd2b</td>\n",
       "      <td>0.220</td>\n",
       "      <td>0.299748</td>\n",
       "      <td>0.350737</td>\n",
       "      <td>1.040317</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "       timestamp                                               ssid  \\\n",
       "0  1560500997770  [7702, 19396, 18304, 19396, 7702, 7702, 19396,...   \n",
       "1  1560500999681  [18304, 7702, 7702, 19396, 19396, 7702, 7702, ...   \n",
       "\n",
       "                                               bssid  \\\n",
       "0  [61027, 55262, 10121, 57287, 45809, 53865, 261...   \n",
       "1  [10121, 31140, 61027, 55262, 57287, 53865, 458...   \n",
       "\n",
       "                                                rssi  \\\n",
       "0  [3.204325463643926, 3.1059258532748903, 2.9091...   \n",
       "1  [2.712327411798748, 2.712327411798748, 2.61392...   \n",
       "\n",
       "                       path   floorNo floor                      site  \\\n",
       "0  5d073b814a19c000086c558b  0.299386    F3  5c3c44b80379370013e0fd2b   \n",
       "1  5d073b814a19c000086c558b  0.299386    F3  5c3c44b80379370013e0fd2b   \n",
       "\n",
       "   wifi_len  wifi_mean  wifi_median  wifi_std  \n",
       "0     0.206   0.353603     0.350737  1.088208  \n",
       "1     0.220   0.299748     0.350737  1.040317  "
      ]
     },
     "execution_count": 24,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def _pad_or_truncate(values, pad_value, length):\n",
    "    \"\"\"Return ``values`` as a list of exactly ``length`` items: the first\n",
    "    ``length`` values, right-padded with ``pad_value`` when there are fewer.\"\"\"\n",
    "    values = list(values)\n",
    "    return values[:length] + [pad_value] * (length - len(values))\n",
    "\n",
    "# Padding value for rssi. rssi is already standardised at this point, so -10 lies far\n",
    "# outside the typical range of real readings.\n",
    "RSSI_PAD = -10\n",
    "# Divisor for the per-timestamp reading count (wifi_len). 500 is presumably a rough upper\n",
    "# bound on readings per scan -- TODO confirm.\n",
    "WIFI_LEN_SCALE = 500\n",
    "\n",
    "# Build one row per (path, wifi timestamp): the ssid/bssid ids and rssi values of that scan\n",
    "# as fixed-length (seqlen) lists, plus per-scan rssi statistics.\n",
    "# The per-path frame is deliberately NOT named `ss`: that global holds the wifi\n",
    "# StandardScaler fitted earlier, and reusing the name here used to overwrite it.\n",
    "train_wifi_frames = []\n",
    "for path, path_df in tqdm(train_wifi_pd_csv.groupby('path')):\n",
    "    path_df = path_df.assign(ssid=path_df['ssid'].apply(lambda x: ssiddict[x]),\n",
    "                             bssid=path_df['bssid'].apply(lambda x: bssiddict[x]))\n",
    "    by_ts = path_df.groupby('timestamp')  # grouped once, reused for every column below\n",
    "    seq_frame = pd.concat([\n",
    "        by_ts['ssid'].apply(lambda x: _pad_or_truncate(x, ssiddict['empty'], seqlen)),\n",
    "        by_ts['bssid'].apply(lambda x: _pad_or_truncate(x, bssiddict['empty'], seqlen)),\n",
    "        by_ts['rssi'].apply(lambda x: _pad_or_truncate(x, RSSI_PAD, seqlen)),\n",
    "    ], axis=1)\n",
    "    seq_frame['path'] = path\n",
    "    seq_frame['floorNo'] = path_df.floorNo.unique()[0]\n",
    "    seq_frame['floor'] = path_df.floor.unique()[0]\n",
    "    seq_frame['site'] = path_df.site.unique()[0]\n",
    "    seq_frame['wifi_len'] = by_ts['rssi'].count() / WIFI_LEN_SCALE\n",
    "    seq_frame['wifi_mean'] = by_ts['rssi'].mean()\n",
    "    seq_frame['wifi_median'] = by_ts['rssi'].median()\n",
    "    seq_frame['wifi_std'] = by_ts['rssi'].std()\n",
    "    train_wifi_frames.append(seq_frame)\n",
    "train_wifi_pd = pd.concat(train_wifi_frames).reset_index()\n",
    "del train_wifi_frames\n",
    "train_wifi_pd.head(2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 626/626 [00:15<00:00, 41.61it/s]\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>timestamp</th>\n",
       "      <th>ssid</th>\n",
       "      <th>bssid</th>\n",
       "      <th>rssi</th>\n",
       "      <th>path</th>\n",
       "      <th>floorNo</th>\n",
       "      <th>wifi_len</th>\n",
       "      <th>wifi_mean</th>\n",
       "      <th>wifi_median</th>\n",
       "      <th>wifi_std</th>\n",
       "      <th>site</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>1180</td>\n",
       "      <td>[7007, 9522, 15215, 18669, 15215, 19396, 4851,...</td>\n",
       "      <td>[35106, 10783, 39335, 4531, 48757, 19211, 1176...</td>\n",
       "      <td>[1.9251305288464635, 1.4331324770012857, 1.334...</td>\n",
       "      <td>00ff0c9a71cc37a2ebdd0f05</td>\n",
       "      <td>0.845957</td>\n",
       "      <td>0.038</td>\n",
       "      <td>0.024464</td>\n",
       "      <td>-0.338061</td>\n",
       "      <td>1.033093</td>\n",
       "      <td>5da1389e4db8ce0c98bd0547</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>3048</td>\n",
       "      <td>[18669, 9522, 7007, 19396, 15215, 15215, 1264,...</td>\n",
       "      <td>[4531, 10783, 35106, 19211, 39335, 48757, 6030...</td>\n",
       "      <td>[2.1219297495845346, 1.4331324770012857, 1.334...</td>\n",
       "      <td>00ff0c9a71cc37a2ebdd0f05</td>\n",
       "      <td>0.845957</td>\n",
       "      <td>0.040</td>\n",
       "      <td>0.075218</td>\n",
       "      <td>-0.338061</td>\n",
       "      <td>0.991529</td>\n",
       "      <td>5da1389e4db8ce0c98bd0547</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>4924</td>\n",
       "      <td>[9522, 18669, 7007, 19396, 15215, 4851, 15215,...</td>\n",
       "      <td>[10783, 4531, 35106, 19211, 48757, 11767, 3933...</td>\n",
       "      <td>[1.4331324770012857, 1.2363332562632146, 1.039...</td>\n",
       "      <td>00ff0c9a71cc37a2ebdd0f05</td>\n",
       "      <td>0.845957</td>\n",
       "      <td>0.048</td>\n",
       "      <td>-0.149461</td>\n",
       "      <td>-0.436460</td>\n",
       "      <td>0.815521</td>\n",
       "      <td>5da1389e4db8ce0c98bd0547</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>6816</td>\n",
       "      <td>[18669, 4851, 15215, 7007, 9522, 19396, 19396,...</td>\n",
       "      <td>[4531, 11767, 39335, 35106, 10783, 19211, 5710...</td>\n",
       "      <td>[1.826730918477428, 1.1379336458941791, 1.0395...</td>\n",
       "      <td>00ff0c9a71cc37a2ebdd0f05</td>\n",
       "      <td>0.845957</td>\n",
       "      <td>0.052</td>\n",
       "      <td>-0.118554</td>\n",
       "      <td>-0.534860</td>\n",
       "      <td>0.911802</td>\n",
       "      <td>5da1389e4db8ce0c98bd0547</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>8693</td>\n",
       "      <td>[18669, 15215, 7007, 4851, 9522, 19396, 15215,...</td>\n",
       "      <td>[4531, 48757, 35106, 11767, 10783, 19211, 3933...</td>\n",
       "      <td>[2.1219297495845346, 1.3347328666322502, 1.334...</td>\n",
       "      <td>00ff0c9a71cc37a2ebdd0f05</td>\n",
       "      <td>0.845957</td>\n",
       "      <td>0.062</td>\n",
       "      <td>-0.182526</td>\n",
       "      <td>-0.534860</td>\n",
       "      <td>0.905339</td>\n",
       "      <td>5da1389e4db8ce0c98bd0547</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   timestamp                                               ssid  \\\n",
       "0       1180  [7007, 9522, 15215, 18669, 15215, 19396, 4851,...   \n",
       "1       3048  [18669, 9522, 7007, 19396, 15215, 15215, 1264,...   \n",
       "2       4924  [9522, 18669, 7007, 19396, 15215, 4851, 15215,...   \n",
       "3       6816  [18669, 4851, 15215, 7007, 9522, 19396, 19396,...   \n",
       "4       8693  [18669, 15215, 7007, 4851, 9522, 19396, 15215,...   \n",
       "\n",
       "                                               bssid  \\\n",
       "0  [35106, 10783, 39335, 4531, 48757, 19211, 1176...   \n",
       "1  [4531, 10783, 35106, 19211, 39335, 48757, 6030...   \n",
       "2  [10783, 4531, 35106, 19211, 48757, 11767, 3933...   \n",
       "3  [4531, 11767, 39335, 35106, 10783, 19211, 5710...   \n",
       "4  [4531, 48757, 35106, 11767, 10783, 19211, 3933...   \n",
       "\n",
       "                                                rssi  \\\n",
       "0  [1.9251305288464635, 1.4331324770012857, 1.334...   \n",
       "1  [2.1219297495845346, 1.4331324770012857, 1.334...   \n",
       "2  [1.4331324770012857, 1.2363332562632146, 1.039...   \n",
       "3  [1.826730918477428, 1.1379336458941791, 1.0395...   \n",
       "4  [2.1219297495845346, 1.3347328666322502, 1.334...   \n",
       "\n",
       "                       path   floorNo  wifi_len  wifi_mean  wifi_median  \\\n",
       "0  00ff0c9a71cc37a2ebdd0f05  0.845957     0.038   0.024464    -0.338061   \n",
       "1  00ff0c9a71cc37a2ebdd0f05  0.845957     0.040   0.075218    -0.338061   \n",
       "2  00ff0c9a71cc37a2ebdd0f05  0.845957     0.048  -0.149461    -0.436460   \n",
       "3  00ff0c9a71cc37a2ebdd0f05  0.845957     0.052  -0.118554    -0.534860   \n",
       "4  00ff0c9a71cc37a2ebdd0f05  0.845957     0.062  -0.182526    -0.534860   \n",
       "\n",
       "   wifi_std                      site  \n",
       "0  1.033093  5da1389e4db8ce0c98bd0547  \n",
       "1  0.991529  5da1389e4db8ce0c98bd0547  \n",
       "2  0.815521  5da1389e4db8ce0c98bd0547  \n",
       "3  0.911802  5da1389e4db8ce0c98bd0547  \n",
       "4  0.905339  5da1389e4db8ce0c98bd0547  "
      ]
     },
     "execution_count": 25,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Same per-scan sequence construction as for the training wifi data, applied to the test paths.\n",
    "# No `floor` column is set for test scans; `site` is looked up from path2site after concatenation.\n",
    "wifi_pads = {'ssid': ssiddict['empty'], 'bssid': bssiddict['empty'], 'rssi': -10}\n",
    "test_wifi_frames = []\n",
    "for path, scans in tqdm(test_wifi_pd_csv.groupby('path')):\n",
    "    scans = scans.assign(ssid=scans['ssid'].map(ssiddict.__getitem__),\n",
    "                         bssid=scans['bssid'].map(bssiddict.__getitem__))\n",
    "    by_ts = scans.groupby('timestamp')\n",
    "    # (readings + pad * seqlen)[:seqlen] truncates long scans and right-pads short ones.\n",
    "    seqs = pd.concat([by_ts[col].apply(lambda v, pad=pad: (list(v) + [pad] * seqlen)[:seqlen])\n",
    "                      for col, pad in wifi_pads.items()], axis=1)\n",
    "    seqs['path'] = path\n",
    "    seqs['floorNo'] = scans.floorNo.unique()[0]\n",
    "    rssi_by_ts = by_ts['rssi']\n",
    "    seqs['wifi_len'] = rssi_by_ts.count() / 500\n",
    "    seqs['wifi_mean'] = rssi_by_ts.mean()\n",
    "    seqs['wifi_median'] = rssi_by_ts.median()\n",
    "    seqs['wifi_std'] = rssi_by_ts.std()\n",
    "    test_wifi_frames.append(seqs)\n",
    "\n",
    "test_wifi_pd = pd.concat(test_wifi_frames).reset_index()\n",
    "test_wifi_pd['site'] = test_wifi_pd['path'].map(path2site.__getitem__)\n",
    "test_wifi_pd.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(37678, 11)"
      ]
     },
     "execution_count": 26,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Sanity check: one row per (path, wifi scan timestamp) in the test set.\n",
    "test_wifi_pd.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 204/204 [00:00<00:00, 266.90it/s]\n"
     ]
    }
   ],
   "source": [
    "# Ground-truth waypoints (path, site, floor, ts_waypoint, x, y) gathered from every training file.\n",
    "WAYPOINT_COLS = ['path', 'site', 'floor', 'ts_waypoint', 'x', 'y']\n",
    "train_xy = pd.concat(\n",
    "    [pd.read_csv(filename, index_col=0)[WAYPOINT_COLS] for filename in tqdm(train_files)]\n",
    ").reset_index(drop=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(166681, 6)"
      ]
     },
     "execution_count": 28,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Keep a single copy of each exactly repeated waypoint row before grouping by path.\n",
    "train_xy = train_xy.drop_duplicates(keep='first')\n",
    "train_xy.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>ts_sensor</th>\n",
       "      <th>x_acce</th>\n",
       "      <th>y_acce</th>\n",
       "      <th>z_acce</th>\n",
       "      <th>x_magne</th>\n",
       "      <th>y_magne</th>\n",
       "      <th>z_magne</th>\n",
       "      <th>x_gyros</th>\n",
       "      <th>y_gyros</th>\n",
       "      <th>z_gyros</th>\n",
       "      <th>x_rotate</th>\n",
       "      <th>y_rotate</th>\n",
       "      <th>z_rotate</th>\n",
       "      <th>path</th>\n",
       "      <th>site</th>\n",
       "      <th>floor</th>\n",
       "      <th>floor_ori</th>\n",
       "      <th>magne</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>1.578463e+12</td>\n",
       "      <td>0.544694</td>\n",
       "      <td>2.170451</td>\n",
       "      <td>-0.215846</td>\n",
       "      <td>-0.061473</td>\n",
       "      <td>0.269189</td>\n",
       "      <td>0.06898</td>\n",
       "      <td>-0.624434</td>\n",
       "      <td>0.608398</td>\n",
       "      <td>-0.327102</td>\n",
       "      <td>3.204804</td>\n",
       "      <td>-1.107252</td>\n",
       "      <td>-0.362515</td>\n",
       "      <td>5e15730aa280850006f3d005</td>\n",
       "      <td>5a0546857ecc773753327266</td>\n",
       "      <td>-1</td>\n",
       "      <td>B1</td>\n",
       "      <td>-0.296067</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>1.578463e+12</td>\n",
       "      <td>0.570460</td>\n",
       "      <td>2.231734</td>\n",
       "      <td>-0.209542</td>\n",
       "      <td>-0.070594</td>\n",
       "      <td>-0.024665</td>\n",
       "      <td>0.02118</td>\n",
       "      <td>-0.561460</td>\n",
       "      <td>0.608398</td>\n",
       "      <td>-0.170987</td>\n",
       "      <td>3.168680</td>\n",
       "      <td>-1.130263</td>\n",
       "      <td>-0.360321</td>\n",
       "      <td>5e15730aa280850006f3d005</td>\n",
       "      <td>5a0546857ecc773753327266</td>\n",
       "      <td>-1</td>\n",
       "      <td>B1</td>\n",
       "      <td>-0.298708</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "      ts_sensor    x_acce    y_acce    z_acce   x_magne   y_magne  z_magne  \\\n",
       "0  1.578463e+12  0.544694  2.170451 -0.215846 -0.061473  0.269189  0.06898   \n",
       "1  1.578463e+12  0.570460  2.231734 -0.209542 -0.070594 -0.024665  0.02118   \n",
       "\n",
       "    x_gyros   y_gyros   z_gyros  x_rotate  y_rotate  z_rotate  \\\n",
       "0 -0.624434  0.608398 -0.327102  3.204804 -1.107252 -0.362515   \n",
       "1 -0.561460  0.608398 -0.170987  3.168680 -1.130263 -0.360321   \n",
       "\n",
       "                       path                      site  floor floor_ori  \\\n",
       "0  5e15730aa280850006f3d005  5a0546857ecc773753327266     -1        B1   \n",
       "1  5e15730aa280850006f3d005  5a0546857ecc773753327266     -1        B1   \n",
       "\n",
       "      magne  \n",
       "0 -0.296067  \n",
       "1 -0.298708  "
      ]
     },
     "execution_count": 30,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Peek at the per-reading sensor table: acce/magne/gyros/rotate components plus path/site/floor.\n",
    "train_sensor_pd_csv.head(2)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Index(['ts_sensor', 'x_acce', 'y_acce', 'z_acce', 'x_magne', 'y_magne',\n",
       "       'z_magne', 'x_gyros', 'y_gyros', 'z_gyros', 'x_rotate', 'y_rotate',\n",
       "       'z_rotate', 'path', 'site', 'floor', 'floor_ori', 'magne'],\n",
       "      dtype='object')"
      ]
     },
     "execution_count": 31,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Column names of the sensor table; the IMU columns are selected by name later in this notebook.\n",
    "train_sensor_pd_csv.columns"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {},
   "outputs": [],
   "source": [
    "# path -> that path's sensor rows, so the per-path loop can look them up directly.\n",
    "train_sensor_pd_csv_group = {path_id: rows for path_id, rows in train_sensor_pd_csv.groupby('path', as_index=False)}\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {},
   "outputs": [],
   "source": [
    "# path -> that path's waypoint rows, so the per-path loop can look them up directly.\n",
    "train_xy_group = {path_id: rows for path_id, rows in train_xy.groupby('path', as_index=False)}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>ts_sensor</th>\n",
       "      <th>x_acce</th>\n",
       "      <th>y_acce</th>\n",
       "      <th>z_acce</th>\n",
       "      <th>x_magne</th>\n",
       "      <th>y_magne</th>\n",
       "      <th>z_magne</th>\n",
       "      <th>x_gyros</th>\n",
       "      <th>y_gyros</th>\n",
       "      <th>z_gyros</th>\n",
       "      <th>x_rotate</th>\n",
       "      <th>y_rotate</th>\n",
       "      <th>z_rotate</th>\n",
       "      <th>path</th>\n",
       "      <th>site</th>\n",
       "      <th>floor</th>\n",
       "      <th>floor_ori</th>\n",
       "      <th>magne</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>1734783</th>\n",
       "      <td>1.560501e+12</td>\n",
       "      <td>-0.994056</td>\n",
       "      <td>-0.248305</td>\n",
       "      <td>0.177490</td>\n",
       "      <td>-0.370349</td>\n",
       "      <td>0.049532</td>\n",
       "      <td>0.332508</td>\n",
       "      <td>0.238923</td>\n",
       "      <td>-0.398348</td>\n",
       "      <td>2.083426</td>\n",
       "      <td>-1.588827</td>\n",
       "      <td>0.498332</td>\n",
       "      <td>1.416640</td>\n",
       "      <td>5d073b814a19c000086c558b</td>\n",
       "      <td>5c3c44b80379370013e0fd2b</td>\n",
       "      <td>2</td>\n",
       "      <td>F3</td>\n",
       "      <td>-0.265570</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1734784</th>\n",
       "      <td>1.560501e+12</td>\n",
       "      <td>-0.377300</td>\n",
       "      <td>-0.036921</td>\n",
       "      <td>0.058472</td>\n",
       "      <td>-1.068294</td>\n",
       "      <td>-3.462430</td>\n",
       "      <td>0.174510</td>\n",
       "      <td>0.238923</td>\n",
       "      <td>-0.398348</td>\n",
       "      <td>2.083426</td>\n",
       "      <td>-1.501465</td>\n",
       "      <td>0.381565</td>\n",
       "      <td>1.418280</td>\n",
       "      <td>5d073b814a19c000086c558b</td>\n",
       "      <td>5c3c44b80379370013e0fd2b</td>\n",
       "      <td>2</td>\n",
       "      <td>F3</td>\n",
       "      <td>0.310682</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1734785</th>\n",
       "      <td>1.560501e+12</td>\n",
       "      <td>0.072799</td>\n",
       "      <td>-0.174701</td>\n",
       "      <td>0.075813</td>\n",
       "      <td>-0.658117</td>\n",
       "      <td>-2.519185</td>\n",
       "      <td>0.323945</td>\n",
       "      <td>0.145778</td>\n",
       "      <td>-0.431256</td>\n",
       "      <td>2.314335</td>\n",
       "      <td>-1.306327</td>\n",
       "      <td>0.301993</td>\n",
       "      <td>1.418707</td>\n",
       "      <td>5d073b814a19c000086c558b</td>\n",
       "      <td>5c3c44b80379370013e0fd2b</td>\n",
       "      <td>2</td>\n",
       "      <td>F3</td>\n",
       "      <td>-0.000635</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1734786</th>\n",
       "      <td>1.560501e+12</td>\n",
       "      <td>-1.088537</td>\n",
       "      <td>0.034150</td>\n",
       "      <td>0.154237</td>\n",
       "      <td>-0.209088</td>\n",
       "      <td>-1.044488</td>\n",
       "      <td>0.316575</td>\n",
       "      <td>0.269925</td>\n",
       "      <td>-0.464091</td>\n",
       "      <td>2.468334</td>\n",
       "      <td>-1.309413</td>\n",
       "      <td>0.313158</td>\n",
       "      <td>1.419393</td>\n",
       "      <td>5d073b814a19c000086c558b</td>\n",
       "      <td>5c3c44b80379370013e0fd2b</td>\n",
       "      <td>2</td>\n",
       "      <td>F3</td>\n",
       "      <td>-0.247187</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1734787</th>\n",
       "      <td>1.560501e+12</td>\n",
       "      <td>-0.625255</td>\n",
       "      <td>-0.065917</td>\n",
       "      <td>-0.063504</td>\n",
       "      <td>0.042315</td>\n",
       "      <td>-0.296481</td>\n",
       "      <td>0.697602</td>\n",
       "      <td>0.176849</td>\n",
       "      <td>-0.431256</td>\n",
       "      <td>2.160513</td>\n",
       "      <td>-1.345893</td>\n",
       "      <td>0.349869</td>\n",
       "      <td>1.420749</td>\n",
       "      <td>5d073b814a19c000086c558b</td>\n",
       "      <td>5c3c44b80379370013e0fd2b</td>\n",
       "      <td>2</td>\n",
       "      <td>F3</td>\n",
       "      <td>-0.255654</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1735679</th>\n",
       "      <td>1.560501e+12</td>\n",
       "      <td>0.562417</td>\n",
       "      <td>0.747711</td>\n",
       "      <td>-0.869028</td>\n",
       "      <td>-0.123092</td>\n",
       "      <td>-0.364089</td>\n",
       "      <td>0.194093</td>\n",
       "      <td>0.642367</td>\n",
       "      <td>-2.370723</td>\n",
       "      <td>4.084173</td>\n",
       "      <td>-0.705909</td>\n",
       "      <td>1.252199</td>\n",
       "      <td>1.423396</td>\n",
       "      <td>5d073b814a19c000086c558b</td>\n",
       "      <td>5c3c44b80379370013e0fd2b</td>\n",
       "      <td>2</td>\n",
       "      <td>F3</td>\n",
       "      <td>-0.289443</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1735680</th>\n",
       "      <td>1.560501e+12</td>\n",
       "      <td>0.600212</td>\n",
       "      <td>0.790128</td>\n",
       "      <td>-0.761443</td>\n",
       "      <td>-0.447272</td>\n",
       "      <td>-0.172782</td>\n",
       "      <td>0.276146</td>\n",
       "      <td>0.642367</td>\n",
       "      <td>-2.436466</td>\n",
       "      <td>4.238171</td>\n",
       "      <td>-0.702471</td>\n",
       "      <td>1.211102</td>\n",
       "      <td>1.424129</td>\n",
       "      <td>5d073b814a19c000086c558b</td>\n",
       "      <td>5c3c44b80379370013e0fd2b</td>\n",
       "      <td>2</td>\n",
       "      <td>F3</td>\n",
       "      <td>-0.256173</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1735681</th>\n",
       "      <td>1.560501e+12</td>\n",
       "      <td>0.679803</td>\n",
       "      <td>0.779619</td>\n",
       "      <td>-0.564002</td>\n",
       "      <td>-0.684603</td>\n",
       "      <td>-0.257777</td>\n",
       "      <td>0.309240</td>\n",
       "      <td>0.611365</td>\n",
       "      <td>-2.436466</td>\n",
       "      <td>4.161084</td>\n",
       "      <td>-0.699489</td>\n",
       "      <td>1.148888</td>\n",
       "      <td>1.425078</td>\n",
       "      <td>5d073b814a19c000086c558b</td>\n",
       "      <td>5c3c44b80379370013e0fd2b</td>\n",
       "      <td>2</td>\n",
       "      <td>F3</td>\n",
       "      <td>-0.204516</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1735682</th>\n",
       "      <td>1.560501e+12</td>\n",
       "      <td>0.668358</td>\n",
       "      <td>0.855755</td>\n",
       "      <td>-0.537991</td>\n",
       "      <td>-0.694506</td>\n",
       "      <td>-0.356393</td>\n",
       "      <td>0.267582</td>\n",
       "      <td>0.580293</td>\n",
       "      <td>-2.304979</td>\n",
       "      <td>4.084173</td>\n",
       "      <td>-0.692203</td>\n",
       "      <td>1.084712</td>\n",
       "      <td>1.425973</td>\n",
       "      <td>5d073b814a19c000086c558b</td>\n",
       "      <td>5c3c44b80379370013e0fd2b</td>\n",
       "      <td>2</td>\n",
       "      <td>F3</td>\n",
       "      <td>-0.202040</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1735683</th>\n",
       "      <td>1.560501e+12</td>\n",
       "      <td>0.814371</td>\n",
       "      <td>0.855755</td>\n",
       "      <td>-0.516908</td>\n",
       "      <td>-0.625874</td>\n",
       "      <td>-0.215307</td>\n",
       "      <td>0.254106</td>\n",
       "      <td>0.611365</td>\n",
       "      <td>-2.370723</td>\n",
       "      <td>3.930350</td>\n",
       "      <td>-0.689897</td>\n",
       "      <td>1.028067</td>\n",
       "      <td>1.426764</td>\n",
       "      <td>5d073b814a19c000086c558b</td>\n",
       "      <td>5c3c44b80379370013e0fd2b</td>\n",
       "      <td>2</td>\n",
       "      <td>F3</td>\n",
       "      <td>-0.221778</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>901 rows × 18 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "            ts_sensor    x_acce    y_acce    z_acce   x_magne   y_magne  \\\n",
       "1734783  1.560501e+12 -0.994056 -0.248305  0.177490 -0.370349  0.049532   \n",
       "1734784  1.560501e+12 -0.377300 -0.036921  0.058472 -1.068294 -3.462430   \n",
       "1734785  1.560501e+12  0.072799 -0.174701  0.075813 -0.658117 -2.519185   \n",
       "1734786  1.560501e+12 -1.088537  0.034150  0.154237 -0.209088 -1.044488   \n",
       "1734787  1.560501e+12 -0.625255 -0.065917 -0.063504  0.042315 -0.296481   \n",
       "...               ...       ...       ...       ...       ...       ...   \n",
       "1735679  1.560501e+12  0.562417  0.747711 -0.869028 -0.123092 -0.364089   \n",
       "1735680  1.560501e+12  0.600212  0.790128 -0.761443 -0.447272 -0.172782   \n",
       "1735681  1.560501e+12  0.679803  0.779619 -0.564002 -0.684603 -0.257777   \n",
       "1735682  1.560501e+12  0.668358  0.855755 -0.537991 -0.694506 -0.356393   \n",
       "1735683  1.560501e+12  0.814371  0.855755 -0.516908 -0.625874 -0.215307   \n",
       "\n",
       "          z_magne   x_gyros   y_gyros   z_gyros  x_rotate  y_rotate  z_rotate  \\\n",
       "1734783  0.332508  0.238923 -0.398348  2.083426 -1.588827  0.498332  1.416640   \n",
       "1734784  0.174510  0.238923 -0.398348  2.083426 -1.501465  0.381565  1.418280   \n",
       "1734785  0.323945  0.145778 -0.431256  2.314335 -1.306327  0.301993  1.418707   \n",
       "1734786  0.316575  0.269925 -0.464091  2.468334 -1.309413  0.313158  1.419393   \n",
       "1734787  0.697602  0.176849 -0.431256  2.160513 -1.345893  0.349869  1.420749   \n",
       "...           ...       ...       ...       ...       ...       ...       ...   \n",
       "1735679  0.194093  0.642367 -2.370723  4.084173 -0.705909  1.252199  1.423396   \n",
       "1735680  0.276146  0.642367 -2.436466  4.238171 -0.702471  1.211102  1.424129   \n",
       "1735681  0.309240  0.611365 -2.436466  4.161084 -0.699489  1.148888  1.425078   \n",
       "1735682  0.267582  0.580293 -2.304979  4.084173 -0.692203  1.084712  1.425973   \n",
       "1735683  0.254106  0.611365 -2.370723  3.930350 -0.689897  1.028067  1.426764   \n",
       "\n",
       "                             path                      site  floor floor_ori  \\\n",
       "1734783  5d073b814a19c000086c558b  5c3c44b80379370013e0fd2b      2        F3   \n",
       "1734784  5d073b814a19c000086c558b  5c3c44b80379370013e0fd2b      2        F3   \n",
       "1734785  5d073b814a19c000086c558b  5c3c44b80379370013e0fd2b      2        F3   \n",
       "1734786  5d073b814a19c000086c558b  5c3c44b80379370013e0fd2b      2        F3   \n",
       "1734787  5d073b814a19c000086c558b  5c3c44b80379370013e0fd2b      2        F3   \n",
       "...                           ...                       ...    ...       ...   \n",
       "1735679  5d073b814a19c000086c558b  5c3c44b80379370013e0fd2b      2        F3   \n",
       "1735680  5d073b814a19c000086c558b  5c3c44b80379370013e0fd2b      2        F3   \n",
       "1735681  5d073b814a19c000086c558b  5c3c44b80379370013e0fd2b      2        F3   \n",
       "1735682  5d073b814a19c000086c558b  5c3c44b80379370013e0fd2b      2        F3   \n",
       "1735683  5d073b814a19c000086c558b  5c3c44b80379370013e0fd2b      2        F3   \n",
       "\n",
       "            magne  \n",
       "1734783 -0.265570  \n",
       "1734784  0.310682  \n",
       "1734785 -0.000635  \n",
       "1734786 -0.247187  \n",
       "1734787 -0.255654  \n",
       "...           ...  \n",
       "1735679 -0.289443  \n",
       "1735680 -0.256173  \n",
       "1735681 -0.204516  \n",
       "1735682 -0.202040  \n",
       "1735683 -0.221778  \n",
       "\n",
       "[901 rows x 18 columns]"
      ]
     },
     "execution_count": 35,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Spot-check the grouped sensor rows of a single training path.\n",
    "train_sensor_pd_csv_group['5d073b814a19c000086c558b']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 10877/10877 [04:10<00:00, 43.34it/s]\n"
     ]
    }
   ],
   "source": [
    "import scipy.stats as stats\n",
    "# Import the submodule explicitly: a bare `import scipy` is not guaranteed to expose scipy.interpolate.\n",
    "import scipy.interpolate\n",
    "\n",
    "SENSOR_COLS = ['x_acce', 'y_acce', 'z_acce',\n",
    "               'x_magne', 'y_magne', 'z_magne',\n",
    "               'x_gyros', 'y_gyros', 'z_gyros',\n",
    "               'x_rotate', 'y_rotate', 'z_rotate']\n",
    "\n",
    "# For each training path, attach to every wifi scan timestamp:\n",
    "#   * x, y: linearly interpolated from the path's waypoints; left NaN for scans outside the\n",
    "#     waypoint time span.\n",
    "#   * SENSOR_COLS: linearly interpolated sensor readings; scans outside the sensor time span\n",
    "#     take the nearest interpolated reading via forward/backward fill.\n",
    "# Paths with no waypoint/sensor group, or fewer than 2 samples (interp1d's minimum), are skipped.\n",
    "train_all = []\n",
    "for path, train_wifi_pd_x in tqdm(train_wifi_pd.groupby('path')):\n",
    "    waypoints = train_xy_group.get(path)\n",
    "    sensors = train_sensor_pd_csv_group.get(path)\n",
    "    if waypoints is None:\n",
    "        print(path, 'have no waypoint')\n",
    "        continue\n",
    "    if sensors is None:\n",
    "        print(path, 'have no sensor data')\n",
    "        continue\n",
    "    train_y = waypoints[['path', 'ts_waypoint', 'x', 'y']].drop_duplicates().reset_index(drop=True)\n",
    "    train_sensor = sensors[['ts_sensor'] + SENSOR_COLS + ['path']].reset_index(drop=True)\n",
    "    if len(train_y) < 2 or len(train_sensor) < 2:\n",
    "        print(path, 'has fewer than 2 waypoints or sensor rows; cannot interpolate')\n",
    "        continue\n",
    "\n",
    "    # Constant 0 placeholder column; not used in this cell (kept for the existing column layout).\n",
    "    train_wifi_pd_x['ts_waypoint'] = 0\n",
    "    wifi_ts = train_wifi_pd_x[['timestamp']].drop_duplicates()\n",
    "    tmp2 = wifi_ts[wifi_ts.timestamp.between(train_y.ts_waypoint.min(), train_y.ts_waypoint.max())].copy()\n",
    "    tmp3 = wifi_ts[wifi_ts.timestamp.between(train_sensor.ts_sensor.min(), train_sensor.ts_sensor.max())].copy()\n",
    "    if len(tmp2) == 0:\n",
    "        continue\n",
    "\n",
    "    # Only timestamps inside each source's time span are queried, so interp1d never extrapolates.\n",
    "    tmp2[['x', 'y']] = scipy.interpolate.interp1d(\n",
    "        train_y['ts_waypoint'], train_y[['x', 'y']], axis=0)(tmp2['timestamp'])\n",
    "    tmp3[SENSOR_COLS] = scipy.interpolate.interp1d(\n",
    "        train_sensor['ts_sensor'], train_sensor[SENSOR_COLS], axis=0)(tmp3['timestamp'])\n",
    "    tmp2['path'] = path\n",
    "    tmp3['path'] = path\n",
    "    train_wifi_pd_x = pd.merge(train_wifi_pd_x, tmp2, how='left', on=['path', 'timestamp'])\n",
    "    train_wifi_pd_x = pd.merge(train_wifi_pd_x, tmp3, how='left', on=['path', 'timestamp'])\n",
    "    # fillna(method=...) is deprecated in pandas >= 2.1; ffill()/bfill() are the equivalent calls.\n",
    "    train_wifi_pd_x[SENSOR_COLS] = train_wifi_pd_x[SENSOR_COLS].ffill().bfill()\n",
    "    train_all.append(train_wifi_pd_x)\n",
    "\n",
    "train_all = pd.concat(train_all).reset_index(drop=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(258097, 27)"
      ]
     },
     "execution_count": 38,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "train_all.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "((11756, 27), (11756, 27))"
      ]
     },
     "execution_count": 40,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "train_all[train_all.x.isna()].shape,train_all[train_all.y.isna()].shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Rows not matched by the left merge with the interpolated waypoints carry NaN targets;\n",
    "# x and y are assigned together, so drop any row missing either coordinate.\n",
    "train_all = train_all.dropna(subset=['x', 'y']).reset_index(drop=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>timestamp</th>\n",
       "      <th>ssid</th>\n",
       "      <th>bssid</th>\n",
       "      <th>rssi</th>\n",
       "      <th>path</th>\n",
       "      <th>floorNo</th>\n",
       "      <th>floor</th>\n",
       "      <th>site</th>\n",
       "      <th>wifi_len</th>\n",
       "      <th>wifi_mean</th>\n",
       "      <th>...</th>\n",
       "      <th>z_acce</th>\n",
       "      <th>x_magne</th>\n",
       "      <th>y_magne</th>\n",
       "      <th>z_magne</th>\n",
       "      <th>x_gyros</th>\n",
       "      <th>y_gyros</th>\n",
       "      <th>z_gyros</th>\n",
       "      <th>x_rotate</th>\n",
       "      <th>y_rotate</th>\n",
       "      <th>z_rotate</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>1560500997770</td>\n",
       "      <td>[7702, 19396, 18304, 19396, 7702, 7702, 19396,...</td>\n",
       "      <td>[61027, 55262, 10121, 57287, 45809, 53865, 261...</td>\n",
       "      <td>[3.204325463643926, 3.1059258532748903, 2.9091...</td>\n",
       "      <td>5d073b814a19c000086c558b</td>\n",
       "      <td>0.299386</td>\n",
       "      <td>F3</td>\n",
       "      <td>5c3c44b80379370013e0fd2b</td>\n",
       "      <td>0.206</td>\n",
       "      <td>0.353603</td>\n",
       "      <td>...</td>\n",
       "      <td>0.461806</td>\n",
       "      <td>0.634199</td>\n",
       "      <td>0.116776</td>\n",
       "      <td>-0.110190</td>\n",
       "      <td>0.983807</td>\n",
       "      <td>-0.595578</td>\n",
       "      <td>0.852319</td>\n",
       "      <td>-0.630592</td>\n",
       "      <td>0.850756</td>\n",
       "      <td>1.353243</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>1560500999681</td>\n",
       "      <td>[18304, 7702, 7702, 19396, 19396, 7702, 7702, ...</td>\n",
       "      <td>[10121, 31140, 61027, 55262, 57287, 53865, 458...</td>\n",
       "      <td>[2.712327411798748, 2.712327411798748, 2.61392...</td>\n",
       "      <td>5d073b814a19c000086c558b</td>\n",
       "      <td>0.299386</td>\n",
       "      <td>F3</td>\n",
       "      <td>5c3c44b80379370013e0fd2b</td>\n",
       "      <td>0.220</td>\n",
       "      <td>0.299748</td>\n",
       "      <td>...</td>\n",
       "      <td>-0.246482</td>\n",
       "      <td>1.202698</td>\n",
       "      <td>-0.395429</td>\n",
       "      <td>-0.167547</td>\n",
       "      <td>1.256958</td>\n",
       "      <td>-0.496999</td>\n",
       "      <td>1.898839</td>\n",
       "      <td>-0.977061</td>\n",
       "      <td>0.819255</td>\n",
       "      <td>1.276234</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>1560501001590</td>\n",
       "      <td>[18304, 19396, 7702, 7702, 19396, 7702, 12721,...</td>\n",
       "      <td>[10121, 57287, 31140, 61027, 55262, 22353, 603...</td>\n",
       "      <td>[3.1059258532748903, 3.1059258532748903, 2.810...</td>\n",
       "      <td>5d073b814a19c000086c558b</td>\n",
       "      <td>0.299386</td>\n",
       "      <td>F3</td>\n",
       "      <td>5c3c44b80379370013e0fd2b</td>\n",
       "      <td>0.238</td>\n",
       "      <td>0.268875</td>\n",
       "      <td>...</td>\n",
       "      <td>-0.704095</td>\n",
       "      <td>-0.460087</td>\n",
       "      <td>0.171733</td>\n",
       "      <td>-0.242435</td>\n",
       "      <td>1.697688</td>\n",
       "      <td>-0.036796</td>\n",
       "      <td>1.083052</td>\n",
       "      <td>-0.492361</td>\n",
       "      <td>1.059535</td>\n",
       "      <td>1.177969</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>1560501003516</td>\n",
       "      <td>[19396, 7702, 19396, 18304, 7702, 7702, 7702, ...</td>\n",
       "      <td>[57287, 31140, 55262, 10121, 22353, 53865, 432...</td>\n",
       "      <td>[3.1059258532748903, 2.8107270221677836, 2.613...</td>\n",
       "      <td>5d073b814a19c000086c558b</td>\n",
       "      <td>0.299386</td>\n",
       "      <td>F3</td>\n",
       "      <td>5c3c44b80379370013e0fd2b</td>\n",
       "      <td>0.258</td>\n",
       "      <td>0.230216</td>\n",
       "      <td>...</td>\n",
       "      <td>0.698048</td>\n",
       "      <td>0.366486</td>\n",
       "      <td>-0.610369</td>\n",
       "      <td>-0.043591</td>\n",
       "      <td>1.107954</td>\n",
       "      <td>0.193343</td>\n",
       "      <td>0.313588</td>\n",
       "      <td>-0.572031</td>\n",
       "      <td>0.860852</td>\n",
       "      <td>1.132166</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>1560501005442</td>\n",
       "      <td>[7702, 18304, 19396, 19396, 7702, 7702, 7702, ...</td>\n",
       "      <td>[31140, 10121, 55262, 57287, 43265, 61027, 612...</td>\n",
       "      <td>[2.8107270221677836, 2.6139278014297127, 2.613...</td>\n",
       "      <td>5d073b814a19c000086c558b</td>\n",
       "      <td>0.299386</td>\n",
       "      <td>F3</td>\n",
       "      <td>5c3c44b80379370013e0fd2b</td>\n",
       "      <td>0.282</td>\n",
       "      <td>0.210465</td>\n",
       "      <td>...</td>\n",
       "      <td>0.136020</td>\n",
       "      <td>-0.585768</td>\n",
       "      <td>-0.707650</td>\n",
       "      <td>0.853478</td>\n",
       "      <td>1.480396</td>\n",
       "      <td>-0.201118</td>\n",
       "      <td>0.544498</td>\n",
       "      <td>-0.470753</td>\n",
       "      <td>0.657864</td>\n",
       "      <td>1.078007</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>5 rows × 27 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "       timestamp                                               ssid  \\\n",
       "0  1560500997770  [7702, 19396, 18304, 19396, 7702, 7702, 19396,...   \n",
       "1  1560500999681  [18304, 7702, 7702, 19396, 19396, 7702, 7702, ...   \n",
       "2  1560501001590  [18304, 19396, 7702, 7702, 19396, 7702, 12721,...   \n",
       "3  1560501003516  [19396, 7702, 19396, 18304, 7702, 7702, 7702, ...   \n",
       "4  1560501005442  [7702, 18304, 19396, 19396, 7702, 7702, 7702, ...   \n",
       "\n",
       "                                               bssid  \\\n",
       "0  [61027, 55262, 10121, 57287, 45809, 53865, 261...   \n",
       "1  [10121, 31140, 61027, 55262, 57287, 53865, 458...   \n",
       "2  [10121, 57287, 31140, 61027, 55262, 22353, 603...   \n",
       "3  [57287, 31140, 55262, 10121, 22353, 53865, 432...   \n",
       "4  [31140, 10121, 55262, 57287, 43265, 61027, 612...   \n",
       "\n",
       "                                                rssi  \\\n",
       "0  [3.204325463643926, 3.1059258532748903, 2.9091...   \n",
       "1  [2.712327411798748, 2.712327411798748, 2.61392...   \n",
       "2  [3.1059258532748903, 3.1059258532748903, 2.810...   \n",
       "3  [3.1059258532748903, 2.8107270221677836, 2.613...   \n",
       "4  [2.8107270221677836, 2.6139278014297127, 2.613...   \n",
       "\n",
       "                       path   floorNo floor                      site  \\\n",
       "0  5d073b814a19c000086c558b  0.299386    F3  5c3c44b80379370013e0fd2b   \n",
       "1  5d073b814a19c000086c558b  0.299386    F3  5c3c44b80379370013e0fd2b   \n",
       "2  5d073b814a19c000086c558b  0.299386    F3  5c3c44b80379370013e0fd2b   \n",
       "3  5d073b814a19c000086c558b  0.299386    F3  5c3c44b80379370013e0fd2b   \n",
       "4  5d073b814a19c000086c558b  0.299386    F3  5c3c44b80379370013e0fd2b   \n",
       "\n",
       "   wifi_len  wifi_mean  ...    z_acce   x_magne   y_magne   z_magne   x_gyros  \\\n",
       "0     0.206   0.353603  ...  0.461806  0.634199  0.116776 -0.110190  0.983807   \n",
       "1     0.220   0.299748  ... -0.246482  1.202698 -0.395429 -0.167547  1.256958   \n",
       "2     0.238   0.268875  ... -0.704095 -0.460087  0.171733 -0.242435  1.697688   \n",
       "3     0.258   0.230216  ...  0.698048  0.366486 -0.610369 -0.043591  1.107954   \n",
       "4     0.282   0.210465  ...  0.136020 -0.585768 -0.707650  0.853478  1.480396   \n",
       "\n",
       "    y_gyros   z_gyros  x_rotate  y_rotate  z_rotate  \n",
       "0 -0.595578  0.852319 -0.630592  0.850756  1.353243  \n",
       "1 -0.496999  1.898839 -0.977061  0.819255  1.276234  \n",
       "2 -0.036796  1.083052 -0.492361  1.059535  1.177969  \n",
       "3  0.193343  0.313588 -0.572031  0.860852  1.132166  \n",
       "4 -0.201118  0.544498 -0.470753  0.657864  1.078007  \n",
       "\n",
       "[5 rows x 27 columns]"
      ]
     },
     "execution_count": 42,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "train_all.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(246341, 27)"
      ]
     },
     "execution_count": 43,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "train_all.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 44,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.model_selection import StratifiedKFold\n",
    "from sklearn.preprocessing import StandardScaler, LabelEncoder\n",
    "# Assign CV folds at the *path* level (stratified by site) instead of per row:\n",
    "# consecutive wifi scans of one trajectory are strongly correlated, so splitting\n",
    "# rows of the same path across folds leaks near-duplicates into validation.\n",
    "N_SPLITS = 10\n",
    "SEED = 0  # previously 42\n",
    "path_site = train_all[['path', 'site']].drop_duplicates('path').reset_index(drop=True)\n",
    "skf = StratifiedKFold(n_splits=N_SPLITS, shuffle=True, random_state=SEED)\n",
    "path_fold = pd.Series(-1, index=path_site['path'])\n",
    "for fold, (_, val_idx) in enumerate(skf.split(path_site, path_site['site'])):\n",
    "    path_fold.iloc[val_idx] = fold\n",
    "# Every row inherits its path's fold; stored as int rather than float.\n",
    "train_all['fold'] = train_all['path'].map(path_fold).astype(int)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 45,
   "metadata": {},
   "outputs": [],
   "source": [
    "# train_all[train_all.path=='5dd3824044333f00067aa2c4'].fold.value_counts()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 46,
   "metadata": {},
   "outputs": [],
   "source": [
    "# train_all[train_all.site=='5c3c44b80379370013e0fd2b'].fold.value_counts()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 47,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>timestamp</th>\n",
       "      <th>ssid</th>\n",
       "      <th>bssid</th>\n",
       "      <th>rssi</th>\n",
       "      <th>path</th>\n",
       "      <th>floorNo</th>\n",
       "      <th>floor</th>\n",
       "      <th>site</th>\n",
       "      <th>wifi_len</th>\n",
       "      <th>wifi_mean</th>\n",
       "      <th>...</th>\n",
       "      <th>x_magne</th>\n",
       "      <th>y_magne</th>\n",
       "      <th>z_magne</th>\n",
       "      <th>x_gyros</th>\n",
       "      <th>y_gyros</th>\n",
       "      <th>z_gyros</th>\n",
       "      <th>x_rotate</th>\n",
       "      <th>y_rotate</th>\n",
       "      <th>z_rotate</th>\n",
       "      <th>fold</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>1560500997770</td>\n",
       "      <td>[7702, 19396, 18304, 19396, 7702, 7702, 19396,...</td>\n",
       "      <td>[61027, 55262, 10121, 57287, 45809, 53865, 261...</td>\n",
       "      <td>[3.204325463643926, 3.1059258532748903, 2.9091...</td>\n",
       "      <td>5d073b814a19c000086c558b</td>\n",
       "      <td>0.299386</td>\n",
       "      <td>F3</td>\n",
       "      <td>5c3c44b80379370013e0fd2b</td>\n",
       "      <td>0.206</td>\n",
       "      <td>0.353603</td>\n",
       "      <td>...</td>\n",
       "      <td>0.634199</td>\n",
       "      <td>0.116776</td>\n",
       "      <td>-0.110190</td>\n",
       "      <td>0.983807</td>\n",
       "      <td>-0.595578</td>\n",
       "      <td>0.852319</td>\n",
       "      <td>-0.630592</td>\n",
       "      <td>0.850756</td>\n",
       "      <td>1.353243</td>\n",
       "      <td>6.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>1560500999681</td>\n",
       "      <td>[18304, 7702, 7702, 19396, 19396, 7702, 7702, ...</td>\n",
       "      <td>[10121, 31140, 61027, 55262, 57287, 53865, 458...</td>\n",
       "      <td>[2.712327411798748, 2.712327411798748, 2.61392...</td>\n",
       "      <td>5d073b814a19c000086c558b</td>\n",
       "      <td>0.299386</td>\n",
       "      <td>F3</td>\n",
       "      <td>5c3c44b80379370013e0fd2b</td>\n",
       "      <td>0.220</td>\n",
       "      <td>0.299748</td>\n",
       "      <td>...</td>\n",
       "      <td>1.202698</td>\n",
       "      <td>-0.395429</td>\n",
       "      <td>-0.167547</td>\n",
       "      <td>1.256958</td>\n",
       "      <td>-0.496999</td>\n",
       "      <td>1.898839</td>\n",
       "      <td>-0.977061</td>\n",
       "      <td>0.819255</td>\n",
       "      <td>1.276234</td>\n",
       "      <td>8.0</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>2 rows × 28 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "       timestamp                                               ssid  \\\n",
       "0  1560500997770  [7702, 19396, 18304, 19396, 7702, 7702, 19396,...   \n",
       "1  1560500999681  [18304, 7702, 7702, 19396, 19396, 7702, 7702, ...   \n",
       "\n",
       "                                               bssid  \\\n",
       "0  [61027, 55262, 10121, 57287, 45809, 53865, 261...   \n",
       "1  [10121, 31140, 61027, 55262, 57287, 53865, 458...   \n",
       "\n",
       "                                                rssi  \\\n",
       "0  [3.204325463643926, 3.1059258532748903, 2.9091...   \n",
       "1  [2.712327411798748, 2.712327411798748, 2.61392...   \n",
       "\n",
       "                       path   floorNo floor                      site  \\\n",
       "0  5d073b814a19c000086c558b  0.299386    F3  5c3c44b80379370013e0fd2b   \n",
       "1  5d073b814a19c000086c558b  0.299386    F3  5c3c44b80379370013e0fd2b   \n",
       "\n",
       "   wifi_len  wifi_mean  ...   x_magne   y_magne   z_magne   x_gyros   y_gyros  \\\n",
       "0     0.206   0.353603  ...  0.634199  0.116776 -0.110190  0.983807 -0.595578   \n",
       "1     0.220   0.299748  ...  1.202698 -0.395429 -0.167547  1.256958 -0.496999   \n",
       "\n",
       "    z_gyros  x_rotate  y_rotate  z_rotate  fold  \n",
       "0  0.852319 -0.630592  0.850756  1.353243   6.0  \n",
       "1  1.898839 -0.977061  0.819255  1.276234   8.0  \n",
       "\n",
       "[2 rows x 28 columns]"
      ]
     },
     "execution_count": 47,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "train_all.head(2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 52,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>timestamp</th>\n",
       "      <th>ssid</th>\n",
       "      <th>bssid</th>\n",
       "      <th>rssi</th>\n",
       "      <th>path</th>\n",
       "      <th>floorNo</th>\n",
       "      <th>floor</th>\n",
       "      <th>site</th>\n",
       "      <th>wifi_len</th>\n",
       "      <th>wifi_mean</th>\n",
       "      <th>...</th>\n",
       "      <th>x_magne</th>\n",
       "      <th>y_magne</th>\n",
       "      <th>z_magne</th>\n",
       "      <th>x_gyros</th>\n",
       "      <th>y_gyros</th>\n",
       "      <th>z_gyros</th>\n",
       "      <th>x_rotate</th>\n",
       "      <th>y_rotate</th>\n",
       "      <th>z_rotate</th>\n",
       "      <th>fold</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>1560500997770</td>\n",
       "      <td>[7702, 19396, 18304, 19396, 7702, 7702, 19396,...</td>\n",
       "      <td>[61027, 55262, 10121, 57287, 45809, 53865, 261...</td>\n",
       "      <td>[3.204325463643926, 3.1059258532748903, 2.9091...</td>\n",
       "      <td>5d073b814a19c000086c558b</td>\n",
       "      <td>0.299386</td>\n",
       "      <td>F3</td>\n",
       "      <td>5c3c44b80379370013e0fd2b</td>\n",
       "      <td>0.206</td>\n",
       "      <td>0.353603</td>\n",
       "      <td>...</td>\n",
       "      <td>0.634199</td>\n",
       "      <td>0.116776</td>\n",
       "      <td>-0.110190</td>\n",
       "      <td>0.983807</td>\n",
       "      <td>-0.595578</td>\n",
       "      <td>0.852319</td>\n",
       "      <td>-0.630592</td>\n",
       "      <td>0.850756</td>\n",
       "      <td>1.353243</td>\n",
       "      <td>6.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>1560500999681</td>\n",
       "      <td>[18304, 7702, 7702, 19396, 19396, 7702, 7702, ...</td>\n",
       "      <td>[10121, 31140, 61027, 55262, 57287, 53865, 458...</td>\n",
       "      <td>[2.712327411798748, 2.712327411798748, 2.61392...</td>\n",
       "      <td>5d073b814a19c000086c558b</td>\n",
       "      <td>0.299386</td>\n",
       "      <td>F3</td>\n",
       "      <td>5c3c44b80379370013e0fd2b</td>\n",
       "      <td>0.220</td>\n",
       "      <td>0.299748</td>\n",
       "      <td>...</td>\n",
       "      <td>1.202698</td>\n",
       "      <td>-0.395429</td>\n",
       "      <td>-0.167547</td>\n",
       "      <td>1.256958</td>\n",
       "      <td>-0.496999</td>\n",
       "      <td>1.898839</td>\n",
       "      <td>-0.977061</td>\n",
       "      <td>0.819255</td>\n",
       "      <td>1.276234</td>\n",
       "      <td>8.0</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>2 rows × 28 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "       timestamp                                               ssid  \\\n",
       "0  1560500997770  [7702, 19396, 18304, 19396, 7702, 7702, 19396,...   \n",
       "1  1560500999681  [18304, 7702, 7702, 19396, 19396, 7702, 7702, ...   \n",
       "\n",
       "                                               bssid  \\\n",
       "0  [61027, 55262, 10121, 57287, 45809, 53865, 261...   \n",
       "1  [10121, 31140, 61027, 55262, 57287, 53865, 458...   \n",
       "\n",
       "                                                rssi  \\\n",
       "0  [3.204325463643926, 3.1059258532748903, 2.9091...   \n",
       "1  [2.712327411798748, 2.712327411798748, 2.61392...   \n",
       "\n",
       "                       path   floorNo floor                      site  \\\n",
       "0  5d073b814a19c000086c558b  0.299386    F3  5c3c44b80379370013e0fd2b   \n",
       "1  5d073b814a19c000086c558b  0.299386    F3  5c3c44b80379370013e0fd2b   \n",
       "\n",
       "   wifi_len  wifi_mean  ...   x_magne   y_magne   z_magne   x_gyros   y_gyros  \\\n",
       "0     0.206   0.353603  ...  0.634199  0.116776 -0.110190  0.983807 -0.595578   \n",
       "1     0.220   0.299748  ...  1.202698 -0.395429 -0.167547  1.256958 -0.496999   \n",
       "\n",
       "    z_gyros  x_rotate  y_rotate  z_rotate  fold  \n",
       "0  0.852319 -0.630592  0.850756  1.353243   6.0  \n",
       "1  1.898839 -0.977061  0.819255  1.276234   8.0  \n",
       "\n",
       "[2 rows x 28 columns]"
      ]
     },
     "execution_count": 52,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "train_all.head(2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 53,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>timestamp</th>\n",
       "      <th>ssid</th>\n",
       "      <th>bssid</th>\n",
       "      <th>rssi</th>\n",
       "      <th>path</th>\n",
       "      <th>floorNo</th>\n",
       "      <th>wifi_len</th>\n",
       "      <th>wifi_mean</th>\n",
       "      <th>wifi_median</th>\n",
       "      <th>wifi_std</th>\n",
       "      <th>site</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>1180</td>\n",
       "      <td>[7007, 9522, 15215, 18669, 15215, 19396, 4851,...</td>\n",
       "      <td>[35106, 10783, 39335, 4531, 48757, 19211, 1176...</td>\n",
       "      <td>[1.9251305288464635, 1.4331324770012857, 1.334...</td>\n",
       "      <td>00ff0c9a71cc37a2ebdd0f05</td>\n",
       "      <td>0.845957</td>\n",
       "      <td>0.038</td>\n",
       "      <td>0.024464</td>\n",
       "      <td>-0.338061</td>\n",
       "      <td>1.033093</td>\n",
       "      <td>5da1389e4db8ce0c98bd0547</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>3048</td>\n",
       "      <td>[18669, 9522, 7007, 19396, 15215, 15215, 1264,...</td>\n",
       "      <td>[4531, 10783, 35106, 19211, 39335, 48757, 6030...</td>\n",
       "      <td>[2.1219297495845346, 1.4331324770012857, 1.334...</td>\n",
       "      <td>00ff0c9a71cc37a2ebdd0f05</td>\n",
       "      <td>0.845957</td>\n",
       "      <td>0.040</td>\n",
       "      <td>0.075218</td>\n",
       "      <td>-0.338061</td>\n",
       "      <td>0.991529</td>\n",
       "      <td>5da1389e4db8ce0c98bd0547</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   timestamp                                               ssid  \\\n",
       "0       1180  [7007, 9522, 15215, 18669, 15215, 19396, 4851,...   \n",
       "1       3048  [18669, 9522, 7007, 19396, 15215, 15215, 1264,...   \n",
       "\n",
       "                                               bssid  \\\n",
       "0  [35106, 10783, 39335, 4531, 48757, 19211, 1176...   \n",
       "1  [4531, 10783, 35106, 19211, 39335, 48757, 6030...   \n",
       "\n",
       "                                                rssi  \\\n",
       "0  [1.9251305288464635, 1.4331324770012857, 1.334...   \n",
       "1  [2.1219297495845346, 1.4331324770012857, 1.334...   \n",
       "\n",
       "                       path   floorNo  wifi_len  wifi_mean  wifi_median  \\\n",
       "0  00ff0c9a71cc37a2ebdd0f05  0.845957     0.038   0.024464    -0.338061   \n",
       "1  00ff0c9a71cc37a2ebdd0f05  0.845957     0.040   0.075218    -0.338061   \n",
       "\n",
       "   wifi_std                      site  \n",
       "0  1.033093  5da1389e4db8ce0c98bd0547  \n",
       "1  0.991529  5da1389e4db8ce0c98bd0547  "
      ]
     },
     "execution_count": 53,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "test_wifi_pd.head(2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 54,
   "metadata": {},
   "outputs": [],
   "source": [
    "# path id -> that path's sensor rows, for direct lookup inside the per-path loop.\n",
    "test_sensor_pd_csv_group = {path_id: frame for path_id, frame in test_sensor_pd_csv.groupby('path')}\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 55,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 626/626 [00:09<00:00, 64.84it/s]\n"
     ]
    }
   ],
   "source": [
    "import warnings\n",
    "import scipy.stats as stats\n",
    "import scipy\n",
    "import scipy.interpolate  # explicit: `import scipy` alone does not load the interpolate submodule on older SciPy\n",
    "\n",
    "# Sensor channels interpolated onto each wifi scan timestamp.\n",
    "SENSOR_COLS = ['x_acce', 'y_acce', 'z_acce',\n",
    "               'x_magne', 'y_magne', 'z_magne',\n",
    "               'x_gyros', 'y_gyros', 'z_gyros',\n",
    "               'x_rotate', 'y_rotate', 'z_rotate']\n",
    "\n",
    "test_all = []\n",
    "skipped_paths = []\n",
    "\n",
    "for path, wifi_path in tqdm(test_wifi_pd.groupby('path')):\n",
    "    sensor_path = test_sensor_pd_csv_group[path][['ts_sensor', *SENSOR_COLS, 'path']].reset_index(drop=True)\n",
    "    wifi_path['ts_waypoint'] = 0\n",
    "\n",
    "    # Only scan timestamps inside the sensor time range can be interpolated;\n",
    "    # the remaining rows are filled from their neighbours after the merge.\n",
    "    ts_min, ts_max = sensor_path.ts_sensor.min(), sensor_path.ts_sensor.max()\n",
    "    scan_ts = wifi_path[['timestamp']].drop_duplicates()\n",
    "    scan_ts = scan_ts[scan_ts.timestamp.between(ts_min, ts_max)]\n",
    "    if scan_ts.empty:\n",
    "        # No overlap: the path cannot be interpolated and is left out of test_all.\n",
    "        skipped_paths.append(path)\n",
    "        continue\n",
    "\n",
    "    interp = scipy.interpolate.interp1d(sensor_path['ts_sensor'], sensor_path[SENSOR_COLS], axis=0)\n",
    "    scan_ts[SENSOR_COLS] = interp(scan_ts['timestamp'])\n",
    "    scan_ts['path'] = path\n",
    "\n",
    "    wifi_path = pd.merge(wifi_path, scan_ts, how='left', on=['path', 'timestamp'])\n",
    "    # Scans outside the sensor range take the nearest interpolated value (forward, then backward).\n",
    "    wifi_path[SENSOR_COLS] = wifi_path[SENSOR_COLS].ffill().bfill()\n",
    "    test_all.append(wifi_path)\n",
    "\n",
    "if skipped_paths:\n",
    "    warnings.warn(f'{len(skipped_paths)} test path(s) have no wifi scan inside the sensor time range '\n",
    "                  f'and are missing from test_all: {skipped_paths[:5]}')\n",
    "\n",
    "test_all = pd.concat(test_all).reset_index(drop=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 58,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "from tqdm import tqdm\n",
    "from sklearn.preprocessing import LabelEncoder\n",
    "from dask.distributed import wait\n",
    "\n",
    "SENSORS = ['acce','acce_uncali','gyro',\n",
    "           'gyro_uncali','magn','magn_uncali','ahrs']\n",
    "\n",
    "NFEAS = {\n",
    "    'acce': 3,\n",
    "    'acce_uncali': 3,\n",
    "    'gyro': 3,\n",
    "    'gyro_uncali': 3,\n",
    "    'magn': 3,\n",
    "    'magn_uncali': 3,\n",
    "    'ahrs': 3,\n",
    "    'wifi': 1,\n",
    "    'ibeacon': 1,\n",
    "    'waypoint': 3\n",
    "}\n",
    "\n",
    "ACOLS = ['timestamp','x','y','z']\n",
    "        \n",
    "FIELDS = {\n",
    "    'acce': ACOLS,\n",
    "    'acce_uncali': ACOLS,\n",
    "    'gyro': ACOLS,\n",
    "    'gyro_uncali': ACOLS,\n",
    "    'magn': ACOLS,\n",
    "    'magn_uncali': ACOLS,\n",
    "    'ahrs': ACOLS,\n",
    "    'wifi': ['timestamp','ssid','bssid','rssi','last_timestamp'],\n",
    "    'ibeacon': ['timestamp','code','rssi','last_timestamp'],\n",
    "    'waypoint': ['timestamp','x','y']\n",
    "}\n",
    "\n",
    "def to_frame(data, col):\n",
    "    # Wrap one parsed sensor array in a DataFrame using the FIELDS[col] column schema.\n",
    "    # An empty array becomes a one-row all-zero placeholder, flagged by is_dummy.\n",
    "    columns = FIELDS[col]\n",
    "    is_dummy = data.shape[0] == 0\n",
    "    frame = create_dummy_df(columns) if is_dummy else pd.DataFrame(data, columns=columns)\n",
    "    # Every column whose name contains 'timestamp' (incl. last_timestamp) is cast to int64.\n",
    "    ts_columns = [name for name in frame.columns if 'timestamp' in name]\n",
    "    for name in ts_columns:\n",
    "        frame[name] = frame[name].astype('int64')\n",
    "    return frame, is_dummy\n",
    "\n",
    "def create_dummy_df(cols):\n",
    "    # One-row frame with 0 in every column; ssid/bssid hold the string '0' instead.\n",
    "    return pd.DataFrame({\n",
    "        name: ['0'] if name in ('ssid', 'bssid') else [0]\n",
    "        for name in cols\n",
    "    })\n",
    "\n",
    "from dataclasses import dataclass\n",
    "\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "@dataclass\n",
    "class ReadData:\n",
    "    acce: np.ndarray\n",
    "    acce_uncali: np.ndarray\n",
    "    gyro: np.ndarray\n",
    "    gyro_uncali: np.ndarray\n",
    "    magn: np.ndarray\n",
    "    magn_uncali: np.ndarray\n",
    "    ahrs: np.ndarray\n",
    "    wifi: np.ndarray\n",
    "    ibeacon: np.ndarray\n",
    "    waypoint: np.ndarray\n",
    "\n",
    "\n",
    "def read_data_file(data_filename):\n",
    "    \"\"\"Parse one tab-separated trace file into a ReadData of numpy arrays.\n",
    "\n",
    "    Blank lines and lines starting with '#' are skipped; the record type is\n",
    "    field 1 of each line. Accelerometer, gyroscope and magnetometer records\n",
    "    (calibrated and uncalibrated) and rotation-vector records become\n",
    "    [int(field 0), float(field 2), float(field 3), float(field 4)];\n",
    "    rotation-vector lines with fewer than 5 fields are dropped. WiFi and\n",
    "    beacon records keep their raw string fields. Waypoints become\n",
    "    [int(field 0), float(field 2), float(field 3)]. Other record types are\n",
    "    ignored, and a type with no records yields an empty array.\n",
    "    \"\"\"\n",
    "    # Record types whose payload is three floats in fields 2-4.\n",
    "    triaxial = {\n",
    "        'TYPE_ACCELEROMETER': [],\n",
    "        'TYPE_ACCELEROMETER_UNCALIBRATED': [],\n",
    "        'TYPE_GYROSCOPE': [],\n",
    "        'TYPE_GYROSCOPE_UNCALIBRATED': [],\n",
    "        'TYPE_MAGNETIC_FIELD': [],\n",
    "        'TYPE_MAGNETIC_FIELD_UNCALIBRATED': [],\n",
    "    }\n",
    "    ahrs, wifi, ibeacon, waypoint = [], [], [], []\n",
    "\n",
    "    with open(data_filename, 'r', encoding='utf-8') as handle:\n",
    "        raw_lines = handle.readlines()\n",
    "\n",
    "    for raw in raw_lines:\n",
    "        record = raw.strip()\n",
    "        if not record or record.startswith('#'):\n",
    "            continue\n",
    "\n",
    "        fields = record.split('\\t')\n",
    "        kind = fields[1]\n",
    "\n",
    "        if kind in triaxial:\n",
    "            triaxial[kind].append([int(fields[0]), float(fields[2]), float(fields[3]), float(fields[4])])\n",
    "        elif kind == 'TYPE_ROTATION_VECTOR':\n",
    "            if len(fields) >= 5:\n",
    "                ahrs.append([int(fields[0]), float(fields[2]), float(fields[3]), float(fields[4])])\n",
    "        elif kind == 'TYPE_WIFI':\n",
    "            # [sys_ts, ssid, bssid, rssi, lastseen_ts]; field 5 is not kept.\n",
    "            wifi.append([fields[0], fields[2], fields[3], fields[4], fields[6]])\n",
    "        elif kind == 'TYPE_BEACON':\n",
    "            # [ts, 'uuid_major_minor', rssi, last field of the line]\n",
    "            beacon_id = '_'.join([fields[2], fields[3], fields[4]])\n",
    "            ibeacon.append([fields[0], beacon_id, fields[6], fields[-1]])\n",
    "        elif kind == 'TYPE_WAYPOINT':\n",
    "            waypoint.append([int(fields[0]), float(fields[2]), float(fields[3])])\n",
    "\n",
    "    return ReadData(\n",
    "        np.array(triaxial['TYPE_ACCELEROMETER']),\n",
    "        np.array(triaxial['TYPE_ACCELEROMETER_UNCALIBRATED']),\n",
    "        np.array(triaxial['TYPE_GYROSCOPE']),\n",
    "        np.array(triaxial['TYPE_GYROSCOPE_UNCALIBRATED']),\n",
    "        np.array(triaxial['TYPE_MAGNETIC_FIELD']),\n",
    "        np.array(triaxial['TYPE_MAGNETIC_FIELD_UNCALIBRATED']),\n",
    "        np.array(ahrs),\n",
    "        np.array(wifi),\n",
    "        np.array(ibeacon),\n",
    "        np.array(waypoint),\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 59,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_test_dfs(PATH, test_files):\n",
    "    \"\"\"Split the sample submission into one frame per test trace file.\n",
    "\n",
    "    Args:\n",
    "        PATH: directory containing sample_submission.csv (read via get_test_df).\n",
    "        test_files: trace file paths; the path id is the file name up to its first '.'.\n",
    "\n",
    "    Returns:\n",
    "        dict mapping each entry of `test_files` to a frame with columns\n",
    "        timestamp, x, y, floor, building, site_path_timestamp and a fresh index.\n",
    "        A file with no submission rows maps to an empty frame with those columns.\n",
    "    \"\"\"\n",
    "    dtest = get_test_df(PATH)\n",
    "    keep = ['timestamp', 'x', 'y', 'floor', 'building', 'site_path_timestamp']\n",
    "    # Group once instead of re-scanning the whole submission for every file.\n",
    "    rows_by_path = {path: rows for path, rows in dtest.groupby('path', sort=False)}\n",
    "    no_rows = dtest.iloc[:0]\n",
    "    dws = {}\n",
    "    for fname in tqdm(test_files):\n",
    "        path = fname.split('/')[-1].split('.')[0]\n",
    "        rows = rows_by_path.get(path, no_rows)\n",
    "        dws[fname] = rows[keep].copy().reset_index(drop=True)\n",
    "    return dws\n",
    "\n",
    "def get_test_df(PATH):\n",
    "    dtest = pd.read_csv(f'{PATH}/sample_submission.csv')\n",
    "    dtest['building'] = dtest['site_path_timestamp'].apply(lambda x: x.split('_')[0])\n",
    "    dtest['path'] = dtest['site_path_timestamp'].apply(lambda x: x.split('_')[1])\n",
    "    dtest['timestamp'] = dtest['site_path_timestamp'].apply(lambda x: x.split('_')[2])\n",
    "    dtest['timestamp'] = dtest['timestamp'].astype('int64')\n",
    "    dtest = dtest.sort_values(['path','timestamp']).reset_index(drop=True)\n",
    "    return dtest\n",
    "\n",
    "def get_time_gap(name):\n",
    "    \"\"\"Estimate the offset between a trace's logged timestamps and its last-seen timestamps.\n",
    "\n",
    "    Args:\n",
    "        name: path of a trace file readable by read_data_file.\n",
    "\n",
    "    Returns:\n",
    "        (gap, no_ibeacon). When the file has iBeacon records, gap is the\n",
    "        (asserted unique) last_timestamp - timestamp over those records and\n",
    "        no_ibeacon is False. Otherwise gap is the largest\n",
    "        lastseen_ts - sys_ts over the WiFi records and no_ibeacon is True.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if the file has neither iBeacon nor WiFi records.\n",
    "    \"\"\"\n",
    "    data = read_data_file(name)\n",
    "    db, no_ibeacon = to_frame(data.ibeacon, 'ibeacon')\n",
    "\n",
    "    if not no_ibeacon:\n",
    "        gap = db['last_timestamp'] - db['timestamp']\n",
    "        assert gap.unique().shape[0] == 1\n",
    "        return gap.values[0], no_ibeacon\n",
    "\n",
    "    if data.wifi.shape[0] == 0:\n",
    "        raise ValueError(f'{name}: no iBeacon or WiFi records to estimate a time gap from')\n",
    "    wifi = pd.DataFrame(data.wifi)\n",
    "    # WiFi fields are parsed as strings: cast before taking the max so the comparison\n",
    "    # is numeric, not lexicographic. sys_ts (col 0) is constant within a scan, so the\n",
    "    # max over scans of (max lastseen_ts - sys_ts) is the max over rows of the difference.\n",
    "    est_ts = (wifi[4].astype('int64') - wifi[0].astype('int64')).max()\n",
    "    return est_ts, no_ibeacon\n",
    "\n",
    "def fix_timestamp_test(df, gap):\n",
    "    df['real_timestamp'] = df['timestamp'] + gap\n",
    "    return df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 60,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['../input/indoor-location-navigation/test/00ff0c9a71cc37a2ebdd0f05.txt',\n",
       " '../input/indoor-location-navigation/test/01c41f1aeba5c48c2c4dd568.txt',\n",
       " '../input/indoor-location-navigation/test/030b3d94de8acae7c936563d.txt',\n",
       " '../input/indoor-location-navigation/test/0389421238a7e2839701df0f.txt']"
      ]
     },
     "execution_count": 60,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# glob returns files in arbitrary order; sort so the file order is reproducible across runs.\n",
    "test_files_ori = sorted(glob.glob('../input/indoor-location-navigation/test/*.txt'))\n",
    "test_files_ori[:4]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 61,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/ec2-user/anaconda3/lib/python3.7/site-packages/distributed/dashboard/core.py:79: UserWarning: \n",
      "Port 8787 is already in use. \n",
      "Perhaps you already have a cluster running?\n",
      "Hosting the diagnostics dashboard on a random port instead.\n",
      "  warnings.warn(\"\\n\" + msg)\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<table style=\"border: 2px solid white;\">\n",
       "<tr>\n",
       "<td style=\"vertical-align: top; border: 0px solid white\">\n",
       "<h3 style=\"text-align: left;\">Client</h3>\n",
       "<ul style=\"text-align: left; list-style: none; margin: 0; padding: 0;\">\n",
       "  <li><b>Scheduler: </b>tcp://127.0.0.1:46533</li>\n",
       "  <li><b>Dashboard: </b><a href='http://127.0.0.1:39225/status' target='_blank'>http://127.0.0.1:39225/status</a>\n",
       "</ul>\n",
       "</td>\n",
       "<td style=\"vertical-align: top; border: 0px solid white\">\n",
       "<h3 style=\"text-align: left;\">Cluster</h3>\n",
       "<ul style=\"text-align: left; list-style:none; margin: 0; padding: 0;\">\n",
       "  <li><b>Workers: </b>8</li>\n",
       "  <li><b>Cores: </b>8</li>\n",
       "  <li><b>Memory: </b>66.71 GB</li>\n",
       "</ul>\n",
       "</td>\n",
       "</tr>\n",
       "</table>"
      ],
      "text/plain": [
       "<Client: 'tcp://127.0.0.1:46533' processes=8 threads=8, memory=66.71 GB>"
      ]
     },
     "execution_count": 61,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import dask\n",
    "from dask.distributed import Client, wait, LocalCluster\n",
    "\n",
    "# One single-threaded worker per CPU core (os.cpu_count() may return None, hence the fallback).\n",
    "N_WORKERS = os.cpu_count() or 1\n",
    "client = Client(n_workers=N_WORKERS,\n",
    "                threads_per_worker=1)\n",
    "client"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 62,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 626/626 [00:00<00:00, 9976.27it/s]\n",
      "100%|██████████| 626/626 [00:18<00:00, 33.98it/s] "
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 2.8 s, sys: 194 ms, total: 3 s\n",
      "Wall time: 18.5 s\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "# Run get_time_gap for every test trace on the dask workers; client.map keeps input order.\n",
    "futures = client.map(get_time_gap, test_files_ori)\n",
    "\n",
    "# Key each (gap, no_ibeacon) result by the trace id (file name without '.txt').\n",
    "testpath2gap = {\n",
    "    fname.split('/')[-1].replace('.txt',''): fut.result()\n",
    "    for fname, fut in tqdm(zip(futures and test_files_ori, futures), total=len(test_files_ori))\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 63,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Shift each test timestamp by its path's estimated gap (element 0 of the get_time_gap result).\n",
    "# NOTE: not idempotent - re-running this cell adds the gap a second time.\n",
    "test_all['timestamp'] = [\n",
    "    ts + testpath2gap[path_id][0]\n",
    "    for ts, path_id in zip(test_all['timestamp'], test_all['path'])\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 64,
   "metadata": {},
   "outputs": [],
   "source": [
    "# test_all['timestamp'] = (test_all['timestamp']-train_all_timestamp_min)/(train_all_timestamp_max-train_all_timestamp_min)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 65,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>timestamp</th>\n",
       "      <th>ssid</th>\n",
       "      <th>bssid</th>\n",
       "      <th>rssi</th>\n",
       "      <th>path</th>\n",
       "      <th>floorNo</th>\n",
       "      <th>wifi_len</th>\n",
       "      <th>wifi_mean</th>\n",
       "      <th>wifi_median</th>\n",
       "      <th>wifi_std</th>\n",
       "      <th>...</th>\n",
       "      <th>z_acce</th>\n",
       "      <th>x_magne</th>\n",
       "      <th>y_magne</th>\n",
       "      <th>z_magne</th>\n",
       "      <th>x_gyros</th>\n",
       "      <th>y_gyros</th>\n",
       "      <th>z_gyros</th>\n",
       "      <th>x_rotate</th>\n",
       "      <th>y_rotate</th>\n",
       "      <th>z_rotate</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>1573190312033</td>\n",
       "      <td>[7007, 9522, 15215, 18669, 15215, 19396, 4851,...</td>\n",
       "      <td>[35106, 10783, 39335, 4531, 48757, 19211, 1176...</td>\n",
       "      <td>[1.9251305288464635, 1.4331324770012857, 1.334...</td>\n",
       "      <td>00ff0c9a71cc37a2ebdd0f05</td>\n",
       "      <td>0.845957</td>\n",
       "      <td>0.038</td>\n",
       "      <td>0.024464</td>\n",
       "      <td>-0.338061</td>\n",
       "      <td>1.033093</td>\n",
       "      <td>...</td>\n",
       "      <td>0.109450</td>\n",
       "      <td>-0.465064</td>\n",
       "      <td>-0.372143</td>\n",
       "      <td>0.303976</td>\n",
       "      <td>1.588786</td>\n",
       "      <td>0.558175</td>\n",
       "      <td>0.87003</td>\n",
       "      <td>0.073650</td>\n",
       "      <td>1.198900</td>\n",
       "      <td>0.804293</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>1573190313901</td>\n",
       "      <td>[18669, 9522, 7007, 19396, 15215, 15215, 1264,...</td>\n",
       "      <td>[4531, 10783, 35106, 19211, 39335, 48757, 6030...</td>\n",
       "      <td>[2.1219297495845346, 1.4331324770012857, 1.334...</td>\n",
       "      <td>00ff0c9a71cc37a2ebdd0f05</td>\n",
       "      <td>0.845957</td>\n",
       "      <td>0.040</td>\n",
       "      <td>0.075218</td>\n",
       "      <td>-0.338061</td>\n",
       "      <td>0.991529</td>\n",
       "      <td>...</td>\n",
       "      <td>-0.080417</td>\n",
       "      <td>-0.153359</td>\n",
       "      <td>1.292601</td>\n",
       "      <td>0.500978</td>\n",
       "      <td>1.588793</td>\n",
       "      <td>0.248723</td>\n",
       "      <td>0.80081</td>\n",
       "      <td>0.203768</td>\n",
       "      <td>1.516082</td>\n",
       "      <td>0.877682</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>2 rows × 24 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "       timestamp                                               ssid  \\\n",
       "0  1573190312033  [7007, 9522, 15215, 18669, 15215, 19396, 4851,...   \n",
       "1  1573190313901  [18669, 9522, 7007, 19396, 15215, 15215, 1264,...   \n",
       "\n",
       "                                               bssid  \\\n",
       "0  [35106, 10783, 39335, 4531, 48757, 19211, 1176...   \n",
       "1  [4531, 10783, 35106, 19211, 39335, 48757, 6030...   \n",
       "\n",
       "                                                rssi  \\\n",
       "0  [1.9251305288464635, 1.4331324770012857, 1.334...   \n",
       "1  [2.1219297495845346, 1.4331324770012857, 1.334...   \n",
       "\n",
       "                       path   floorNo  wifi_len  wifi_mean  wifi_median  \\\n",
       "0  00ff0c9a71cc37a2ebdd0f05  0.845957     0.038   0.024464    -0.338061   \n",
       "1  00ff0c9a71cc37a2ebdd0f05  0.845957     0.040   0.075218    -0.338061   \n",
       "\n",
       "   wifi_std  ...    z_acce   x_magne   y_magne   z_magne   x_gyros   y_gyros  \\\n",
       "0  1.033093  ...  0.109450 -0.465064 -0.372143  0.303976  1.588786  0.558175   \n",
       "1  0.991529  ... -0.080417 -0.153359  1.292601  0.500978  1.588793  0.248723   \n",
       "\n",
       "   z_gyros  x_rotate  y_rotate  z_rotate  \n",
       "0  0.87003  0.073650  1.198900  0.804293  \n",
       "1  0.80081  0.203768  1.516082  0.877682  \n",
       "\n",
       "[2 rows x 24 columns]"
      ]
     },
     "execution_count": 65,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Preview the first two rows of test_all.\n",
    "test_all.head(2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 66,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Standardise timestamps with statistics fitted on train only, then apply the same\n",
    "# transform to test. NOTE: re-running this cell rescales already-scaled values.\n",
    "ss2 = StandardScaler()\n",
    "train_all.loc[:, ['timestamp']] = ss2.fit_transform(train_all.loc[:, ['timestamp']])\n",
    "test_all.loc[:, ['timestamp']] = ss2.transform(test_all.loc[:, ['timestamp']])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 67,
   "metadata": {},
   "outputs": [],
   "source": [
    "# train_all_floor_min = train_all.floor.min()\n",
    "# train_all_floor_max = train_all.floor.max()\n",
    "# train_all['floor'] = (train_all['floor']-train_all_floor_min)/(train_all_floor_max-train_all_floor_min)\n",
    "# test_all['floor'] = (test_all['floor']-train_all_floor_min)/(train_all_floor_max-train_all_floor_min)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 68,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Integer-encode sites in sorted order. The mapping is built from train only,\n",
    "# so a test site missing from train raises KeyError.\n",
    "sitelist = sorted(set(train_all.site))\n",
    "sitedict = {site: idx for idx, site in enumerate(sitelist)}\n",
    "train_all['site_id'] = train_all['site'].map(sitedict.__getitem__)\n",
    "test_all['site_id'] = test_all['site'].map(sitedict.__getitem__)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 69,
   "metadata": {},
   "outputs": [],
   "source": [
    "def MCRMSE(y_true, y_pred):\n",
    "    \"\"\"Mean column-wise RMSE.\n",
    "\n",
    "    Takes the mean squared error over axis 1, then its square root, then the mean\n",
    "    over the (new) axis 1. Both reductions use axis=1, so this needs rank-3\n",
    "    tensors (presumably batch, steps, targets) and returns one value per batch\n",
    "    element; a rank-2 input fails on the second reduction.\n",
    "    \"\"\"\n",
    "    squared_error = tf.square(y_true - y_pred)\n",
    "    per_column_rmse = tf.sqrt(tf.reduce_mean(squared_error, axis=1))\n",
    "    return tf.reduce_mean(per_column_rmse, axis=1)\n",
    "\n",
    "def gru_layer(hidden_dim, dropout):\n",
    "    \"\"\"Bidirectional GRU that returns the full output sequence (orthogonal kernel init).\"\"\"\n",
    "    gru = L.GRU(hidden_dim,\n",
    "                dropout=dropout,\n",
    "                return_sequences=True,\n",
    "                kernel_initializer='orthogonal')\n",
    "    return L.Bidirectional(gru)\n",
    "\n",
    "def pandas_list_to_array(df):\n",
    "    \"\"\"\n",
    "    Input: dataframe of shape (x, y), containing list of length l\n",
    "    Return: np.array of shape (x, l, y)\n",
    "    \"\"\"\n",
    "    \n",
    "    return np.transpose(\n",
    "        np.array(df.values.tolist()),\n",
    "        (0, 2, 1)\n",
    "    )\n",
    "\n",
    "def preprocess_inputs(df, cols=('ssid', 'bssid', 'rssi')):\n",
    "    \"\"\"Turn the list-valued `cols` of `df` into a (rows, list_len, len(cols)) array.\n",
    "\n",
    "    `cols` may be any sequence of column names. It is converted to a list because\n",
    "    pandas treats a tuple inside [] as a single (MultiIndex) key. The default is\n",
    "    an immutable tuple so it cannot be mutated between calls.\n",
    "    \"\"\"\n",
    "    return pandas_list_to_array(df[list(cols)])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 70,
   "metadata": {},
   "outputs": [],
   "source": [
    "def build_model_mix(sid_size,bssid_size,site_size, seq_len=100, pred_len=2, dropout=0.2, \n",
    "                sp_dropout=0.1, embed_dim=64, hidden_dim=128, n_layers=3,lr=0.001,\n",
    "                time_dim=4+12, site_embed_dim=1):\n",
    "    \"\"\"Build and compile the sequence model that regresses 2 outputs.\n",
    "\n",
    "    Model inputs, in this order:\n",
    "        inputs: (seq_len, 3) per sample. Channel 0 is an id embedded with a\n",
    "            sid_size vocabulary, channel 1 an id embedded with a bssid_size\n",
    "            vocabulary, channel 2 a numeric value passed through unchanged\n",
    "            (presumably ssid, bssid and rssi, as in preprocess_inputs).\n",
    "        input_time: (time_dim,) dense side features.\n",
    "        input_site: (1,) site id embedded with a site_size vocabulary.\n",
    "\n",
    "    The per-step features go through n_layers bidirectional GRUs. The first\n",
    "    pred_len steps are flattened, concatenated with input_time and the site\n",
    "    embedding, and mapped to 2 linear outputs. Compiled with Adam(lr) and MSE loss.\n",
    "\n",
    "    Args:\n",
    "        time_dim: width of input_time (default 16, formerly hard-coded as 4+12).\n",
    "        site_embed_dim: width of the site embedding (default 1, formerly hard-coded).\n",
    "    \"\"\"\n",
    "    inputs = L.Input(shape=(seq_len, 3))\n",
    "    input_time = L.Input(shape=(time_dim,))\n",
    "    input_site = L.Input(shape=(1,))\n",
    "\n",
    "    categorical_fea1 = inputs[:, :, :1]\n",
    "    categorical_fea2 = inputs[:, :, 1:2]\n",
    "    numerical_fea = inputs[:, :, 2:]\n",
    "\n",
    "    # Embedding of a (batch, seq_len, 1) input is (batch, seq_len, 1, embed_dim);\n",
    "    # fold the last two axes into a single feature axis per step.\n",
    "    embed = L.Embedding(input_dim=sid_size, output_dim=embed_dim)(categorical_fea1)\n",
    "    reshaped = tf.reshape(embed, shape=(-1, embed.shape[1],  embed.shape[2] * embed.shape[3]))\n",
    "    reshaped = L.SpatialDropout1D(sp_dropout)(reshaped)\n",
    "\n",
    "    embed2 = L.Embedding(input_dim=bssid_size, output_dim=embed_dim)(categorical_fea2)\n",
    "    reshaped2 = tf.reshape(embed2, shape=(-1, embed2.shape[1],  embed2.shape[2] * embed2.shape[3]))\n",
    "    reshaped2 = L.SpatialDropout1D(sp_dropout)(reshaped2)\n",
    "\n",
    "    hidden = L.concatenate([reshaped, reshaped2, numerical_fea], axis=2)\n",
    "\n",
    "    for _ in range(n_layers):\n",
    "        hidden = gru_layer(hidden_dim, dropout)(hidden)\n",
    "\n",
    "    # Keep only the first pred_len GRU outputs as the sequence summary.\n",
    "    truncated = hidden[:, :pred_len]\n",
    "    truncated = L.Flatten()(truncated)\n",
    "\n",
    "    embed_site = L.Embedding(input_dim=site_size, output_dim=site_embed_dim)(input_site)\n",
    "    embed_site = L.Flatten()(embed_site)\n",
    "\n",
    "    truncated = L.concatenate([truncated, input_time, embed_site], axis=1)\n",
    "\n",
    "    out = L.Dense(2, activation='linear')(truncated)\n",
    "\n",
    "    model = tf.keras.Model(inputs=[inputs, input_time, input_site], outputs=out)\n",
    "    model.compile(tf.optimizers.Adam(lr), loss='mse')\n",
    "\n",
    "    return model\n",
    "\n",
    "def get_embed_size(n_cat):\n",
    "    return min(600, round(1.6 * n_cat ** .56))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 71,
   "metadata": {},
   "outputs": [],
   "source": [
    "# import pickle\n",
    "# with open('train_all.pickle','wb') as fw:\n",
    "#     pickle.dump(train_all,fw)\n",
    "# with open('test_all.pickle','wb') as fw:\n",
    "#     pickle.dump(test_all,fw)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 72,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Index(['timestamp', 'ssid', 'bssid', 'rssi', 'path', 'floorNo', 'wifi_len',\n",
       "       'wifi_mean', 'wifi_median', 'wifi_std', 'site', 'ts_waypoint', 'x_acce',\n",
       "       'y_acce', 'z_acce', 'x_magne', 'y_magne', 'z_magne', 'x_gyros',\n",
       "       'y_gyros', 'z_gyros', 'x_rotate', 'y_rotate', 'z_rotate', 'site_id'],\n",
       "      dtype='object')"
      ]
     },
     "execution_count": 72,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# List the columns available in test_all.\n",
    "test_all.columns"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 73,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Inspect one training path with missing values back-filled. The boolean selection\n",
    "# is a copy and bfill returns a new frame, so train_all itself is not modified.\n",
    "ss = train_all[train_all.path=='5dd4b80627889b0006b7772d'].bfill()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 74,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>timestamp</th>\n",
       "      <th>ssid</th>\n",
       "      <th>bssid</th>\n",
       "      <th>rssi</th>\n",
       "      <th>path</th>\n",
       "      <th>floorNo</th>\n",
       "      <th>floor</th>\n",
       "      <th>site</th>\n",
       "      <th>wifi_len</th>\n",
       "      <th>wifi_mean</th>\n",
       "      <th>...</th>\n",
       "      <th>y_magne</th>\n",
       "      <th>z_magne</th>\n",
       "      <th>x_gyros</th>\n",
       "      <th>y_gyros</th>\n",
       "      <th>z_gyros</th>\n",
       "      <th>x_rotate</th>\n",
       "      <th>y_rotate</th>\n",
       "      <th>z_rotate</th>\n",
       "      <th>fold</th>\n",
       "      <th>site_id</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>224653</th>\n",
       "      <td>0.707866</td>\n",
       "      <td>[15433, 15433, 13482, 19396, 13482, 19396, 154...</td>\n",
       "      <td>[63855, 31672, 22276, 58558, 30991, 19176, 331...</td>\n",
       "      <td>[2.4171285806916414, 1.826730918477428, 1.8267...</td>\n",
       "      <td>5dd4b80627889b0006b7772d</td>\n",
       "      <td>1.939098</td>\n",
       "      <td>F6</td>\n",
       "      <td>5dbc1d84c1eb61796cf7c010</td>\n",
       "      <td>0.262</td>\n",
       "      <td>-0.325291</td>\n",
       "      <td>...</td>\n",
       "      <td>0.713641</td>\n",
       "      <td>0.364087</td>\n",
       "      <td>1.094598</td>\n",
       "      <td>-1.136477</td>\n",
       "      <td>0.824800</td>\n",
       "      <td>-1.021430</td>\n",
       "      <td>0.694993</td>\n",
       "      <td>1.354620</td>\n",
       "      <td>0.0</td>\n",
       "      <td>22</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>224654</th>\n",
       "      <td>0.707867</td>\n",
       "      <td>[15433, 15433, 13482, 19396, 11875, 15433, 976...</td>\n",
       "      <td>[31672, 63855, 30991, 19176, 31883, 33131, 450...</td>\n",
       "      <td>[2.318728970322606, 1.6299316977393568, 1.3347...</td>\n",
       "      <td>5dd4b80627889b0006b7772d</td>\n",
       "      <td>1.939098</td>\n",
       "      <td>F6</td>\n",
       "      <td>5dbc1d84c1eb61796cf7c010</td>\n",
       "      <td>0.260</td>\n",
       "      <td>-0.363039</td>\n",
       "      <td>...</td>\n",
       "      <td>0.713641</td>\n",
       "      <td>0.364087</td>\n",
       "      <td>1.094598</td>\n",
       "      <td>-1.136477</td>\n",
       "      <td>0.824800</td>\n",
       "      <td>-1.021430</td>\n",
       "      <td>0.694993</td>\n",
       "      <td>1.354620</td>\n",
       "      <td>6.0</td>\n",
       "      <td>22</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>224655</th>\n",
       "      <td>0.707867</td>\n",
       "      <td>[13482, 15433, 19396, 11875, 9761, 15433, 1187...</td>\n",
       "      <td>[30991, 63855, 19176, 31883, 45009, 31672, 389...</td>\n",
       "      <td>[1.3347328666322502, 1.3347328666322502, 1.236...</td>\n",
       "      <td>5dd4b80627889b0006b7772d</td>\n",
       "      <td>1.939098</td>\n",
       "      <td>F6</td>\n",
       "      <td>5dbc1d84c1eb61796cf7c010</td>\n",
       "      <td>0.260</td>\n",
       "      <td>-0.392559</td>\n",
       "      <td>...</td>\n",
       "      <td>1.250962</td>\n",
       "      <td>1.443087</td>\n",
       "      <td>0.997709</td>\n",
       "      <td>-1.189775</td>\n",
       "      <td>0.049868</td>\n",
       "      <td>-1.275551</td>\n",
       "      <td>-0.050002</td>\n",
       "      <td>1.369476</td>\n",
       "      <td>8.0</td>\n",
       "      <td>22</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>224656</th>\n",
       "      <td>0.707868</td>\n",
       "      <td>[15433, 11875, 9761, 11875, 13482, 19396, 1503...</td>\n",
       "      <td>[63855, 31883, 45009, 38916, 42558, 59902, 172...</td>\n",
       "      <td>[1.3347328666322502, 0.8427348147870724, 0.645...</td>\n",
       "      <td>5dd4b80627889b0006b7772d</td>\n",
       "      <td>1.939098</td>\n",
       "      <td>F6</td>\n",
       "      <td>5dbc1d84c1eb61796cf7c010</td>\n",
       "      <td>0.258</td>\n",
       "      <td>-0.459344</td>\n",
       "      <td>...</td>\n",
       "      <td>1.017085</td>\n",
       "      <td>-1.334331</td>\n",
       "      <td>0.750805</td>\n",
       "      <td>-0.969883</td>\n",
       "      <td>-0.647448</td>\n",
       "      <td>-1.030683</td>\n",
       "      <td>-0.485144</td>\n",
       "      <td>1.381195</td>\n",
       "      <td>4.0</td>\n",
       "      <td>22</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>224657</th>\n",
       "      <td>0.707869</td>\n",
       "      <td>[15433, 11875, 9761, 11875, 13482, 19396, 1348...</td>\n",
       "      <td>[63855, 31883, 45009, 38916, 42558, 59902, 265...</td>\n",
       "      <td>[1.3347328666322502, 0.8427348147870724, 0.645...</td>\n",
       "      <td>5dd4b80627889b0006b7772d</td>\n",
       "      <td>1.939098</td>\n",
       "      <td>F6</td>\n",
       "      <td>5dbc1d84c1eb61796cf7c010</td>\n",
       "      <td>0.256</td>\n",
       "      <td>-0.468747</td>\n",
       "      <td>...</td>\n",
       "      <td>-0.246060</td>\n",
       "      <td>1.202760</td>\n",
       "      <td>1.125877</td>\n",
       "      <td>-1.203099</td>\n",
       "      <td>0.514863</td>\n",
       "      <td>-0.930176</td>\n",
       "      <td>0.546333</td>\n",
       "      <td>1.365165</td>\n",
       "      <td>8.0</td>\n",
       "      <td>22</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>224658</th>\n",
       "      <td>0.707869</td>\n",
       "      <td>[15433, 11875, 9761, 13482, 19396, 13482, 1939...</td>\n",
       "      <td>[63855, 31883, 45009, 26529, 62405, 15051, 153...</td>\n",
       "      <td>[1.3347328666322502, 0.8427348147870724, 0.645...</td>\n",
       "      <td>5dd4b80627889b0006b7772d</td>\n",
       "      <td>1.939098</td>\n",
       "      <td>F6</td>\n",
       "      <td>5dbc1d84c1eb61796cf7c010</td>\n",
       "      <td>0.234</td>\n",
       "      <td>-0.422163</td>\n",
       "      <td>...</td>\n",
       "      <td>0.657207</td>\n",
       "      <td>0.425335</td>\n",
       "      <td>1.844603</td>\n",
       "      <td>-1.069854</td>\n",
       "      <td>-0.492567</td>\n",
       "      <td>-1.235080</td>\n",
       "      <td>0.291654</td>\n",
       "      <td>1.367361</td>\n",
       "      <td>2.0</td>\n",
       "      <td>22</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>224659</th>\n",
       "      <td>0.707870</td>\n",
       "      <td>[15433, 11875, 13482, 19396, 9761, 13482, 1939...</td>\n",
       "      <td>[63855, 31883, 26529, 62405, 45009, 15051, 153...</td>\n",
       "      <td>[1.3347328666322502, 0.8427348147870724, 0.842...</td>\n",
       "      <td>5dd4b80627889b0006b7772d</td>\n",
       "      <td>1.939098</td>\n",
       "      <td>F6</td>\n",
       "      <td>5dbc1d84c1eb61796cf7c010</td>\n",
       "      <td>0.238</td>\n",
       "      <td>-0.380232</td>\n",
       "      <td>...</td>\n",
       "      <td>-0.634699</td>\n",
       "      <td>0.024486</td>\n",
       "      <td>1.094598</td>\n",
       "      <td>-1.536361</td>\n",
       "      <td>-0.260070</td>\n",
       "      <td>-1.412949</td>\n",
       "      <td>-0.177800</td>\n",
       "      <td>1.355965</td>\n",
       "      <td>7.0</td>\n",
       "      <td>22</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>224660</th>\n",
       "      <td>0.707871</td>\n",
       "      <td>[13482, 19396, 19396, 15433, 14772, 11875, 193...</td>\n",
       "      <td>[15051, 12689, 15371, 63855, 31025, 31883, 126...</td>\n",
       "      <td>[1.5315320873703213, 1.5315320873703213, 1.531...</td>\n",
       "      <td>5dd4b80627889b0006b7772d</td>\n",
       "      <td>1.939098</td>\n",
       "      <td>F6</td>\n",
       "      <td>5dbc1d84c1eb61796cf7c010</td>\n",
       "      <td>0.246</td>\n",
       "      <td>-0.278861</td>\n",
       "      <td>...</td>\n",
       "      <td>0.412877</td>\n",
       "      <td>0.254555</td>\n",
       "      <td>0.625828</td>\n",
       "      <td>-1.536361</td>\n",
       "      <td>0.359982</td>\n",
       "      <td>-1.125339</td>\n",
       "      <td>0.728169</td>\n",
       "      <td>1.377426</td>\n",
       "      <td>4.0</td>\n",
       "      <td>22</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>224661</th>\n",
       "      <td>0.707871</td>\n",
       "      <td>[19396, 19396, 13482, 13482, 19396, 11875, 147...</td>\n",
       "      <td>[12690, 36864, 37924, 26529, 62405, 31883, 310...</td>\n",
       "      <td>[1.4331324770012857, 1.4331324770012857, 1.433...</td>\n",
       "      <td>5dd4b80627889b0006b7772d</td>\n",
       "      <td>1.939098</td>\n",
       "      <td>F6</td>\n",
       "      <td>5dbc1d84c1eb61796cf7c010</td>\n",
       "      <td>0.258</td>\n",
       "      <td>-0.282377</td>\n",
       "      <td>...</td>\n",
       "      <td>-0.508025</td>\n",
       "      <td>3.236964</td>\n",
       "      <td>-0.084674</td>\n",
       "      <td>-1.573146</td>\n",
       "      <td>0.058029</td>\n",
       "      <td>-1.486332</td>\n",
       "      <td>0.197214</td>\n",
       "      <td>1.448679</td>\n",
       "      <td>1.0</td>\n",
       "      <td>22</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>224662</th>\n",
       "      <td>0.707872</td>\n",
       "      <td>[19396, 19396, 13482, 13482, 19396, 19396, 134...</td>\n",
       "      <td>[12690, 36864, 37924, 26529, 62405, 15371, 150...</td>\n",
       "      <td>[1.6299316977393568, 1.6299316977393568, 1.629...</td>\n",
       "      <td>5dd4b80627889b0006b7772d</td>\n",
       "      <td>1.939098</td>\n",
       "      <td>F6</td>\n",
       "      <td>5dbc1d84c1eb61796cf7c010</td>\n",
       "      <td>0.258</td>\n",
       "      <td>-0.258731</td>\n",
       "      <td>...</td>\n",
       "      <td>0.569195</td>\n",
       "      <td>0.178858</td>\n",
       "      <td>-1.374161</td>\n",
       "      <td>-0.536651</td>\n",
       "      <td>-0.570007</td>\n",
       "      <td>0.243419</td>\n",
       "      <td>-0.569259</td>\n",
       "      <td>-1.188966</td>\n",
       "      <td>5.0</td>\n",
       "      <td>22</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>224663</th>\n",
       "      <td>0.707873</td>\n",
       "      <td>[13482, 19396, 19396, 19396, 13482, 19396, 118...</td>\n",
       "      <td>[26529, 62405, 12690, 36864, 37924, 12689, 318...</td>\n",
       "      <td>[1.9251305288464635, 1.9251305288464635, 1.728...</td>\n",
       "      <td>5dd4b80627889b0006b7772d</td>\n",
       "      <td>1.939098</td>\n",
       "      <td>F6</td>\n",
       "      <td>5dbc1d84c1eb61796cf7c010</td>\n",
       "      <td>0.256</td>\n",
       "      <td>-0.281942</td>\n",
       "      <td>...</td>\n",
       "      <td>-0.559328</td>\n",
       "      <td>-4.136266</td>\n",
       "      <td>1.438321</td>\n",
       "      <td>-0.803216</td>\n",
       "      <td>0.514863</td>\n",
       "      <td>-1.137083</td>\n",
       "      <td>0.331922</td>\n",
       "      <td>1.231725</td>\n",
       "      <td>1.0</td>\n",
       "      <td>22</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>11 rows × 29 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "        timestamp                                               ssid  \\\n",
       "224653   0.707866  [15433, 15433, 13482, 19396, 13482, 19396, 154...   \n",
       "224654   0.707867  [15433, 15433, 13482, 19396, 11875, 15433, 976...   \n",
       "224655   0.707867  [13482, 15433, 19396, 11875, 9761, 15433, 1187...   \n",
       "224656   0.707868  [15433, 11875, 9761, 11875, 13482, 19396, 1503...   \n",
       "224657   0.707869  [15433, 11875, 9761, 11875, 13482, 19396, 1348...   \n",
       "224658   0.707869  [15433, 11875, 9761, 13482, 19396, 13482, 1939...   \n",
       "224659   0.707870  [15433, 11875, 13482, 19396, 9761, 13482, 1939...   \n",
       "224660   0.707871  [13482, 19396, 19396, 15433, 14772, 11875, 193...   \n",
       "224661   0.707871  [19396, 19396, 13482, 13482, 19396, 11875, 147...   \n",
       "224662   0.707872  [19396, 19396, 13482, 13482, 19396, 19396, 134...   \n",
       "224663   0.707873  [13482, 19396, 19396, 19396, 13482, 19396, 118...   \n",
       "\n",
       "                                                    bssid  \\\n",
       "224653  [63855, 31672, 22276, 58558, 30991, 19176, 331...   \n",
       "224654  [31672, 63855, 30991, 19176, 31883, 33131, 450...   \n",
       "224655  [30991, 63855, 19176, 31883, 45009, 31672, 389...   \n",
       "224656  [63855, 31883, 45009, 38916, 42558, 59902, 172...   \n",
       "224657  [63855, 31883, 45009, 38916, 42558, 59902, 265...   \n",
       "224658  [63855, 31883, 45009, 26529, 62405, 15051, 153...   \n",
       "224659  [63855, 31883, 26529, 62405, 45009, 15051, 153...   \n",
       "224660  [15051, 12689, 15371, 63855, 31025, 31883, 126...   \n",
       "224661  [12690, 36864, 37924, 26529, 62405, 31883, 310...   \n",
       "224662  [12690, 36864, 37924, 26529, 62405, 15371, 150...   \n",
       "224663  [26529, 62405, 12690, 36864, 37924, 12689, 318...   \n",
       "\n",
       "                                                     rssi  \\\n",
       "224653  [2.4171285806916414, 1.826730918477428, 1.8267...   \n",
       "224654  [2.318728970322606, 1.6299316977393568, 1.3347...   \n",
       "224655  [1.3347328666322502, 1.3347328666322502, 1.236...   \n",
       "224656  [1.3347328666322502, 0.8427348147870724, 0.645...   \n",
       "224657  [1.3347328666322502, 0.8427348147870724, 0.645...   \n",
       "224658  [1.3347328666322502, 0.8427348147870724, 0.645...   \n",
       "224659  [1.3347328666322502, 0.8427348147870724, 0.842...   \n",
       "224660  [1.5315320873703213, 1.5315320873703213, 1.531...   \n",
       "224661  [1.4331324770012857, 1.4331324770012857, 1.433...   \n",
       "224662  [1.6299316977393568, 1.6299316977393568, 1.629...   \n",
       "224663  [1.9251305288464635, 1.9251305288464635, 1.728...   \n",
       "\n",
       "                            path   floorNo floor                      site  \\\n",
       "224653  5dd4b80627889b0006b7772d  1.939098    F6  5dbc1d84c1eb61796cf7c010   \n",
       "224654  5dd4b80627889b0006b7772d  1.939098    F6  5dbc1d84c1eb61796cf7c010   \n",
       "224655  5dd4b80627889b0006b7772d  1.939098    F6  5dbc1d84c1eb61796cf7c010   \n",
       "224656  5dd4b80627889b0006b7772d  1.939098    F6  5dbc1d84c1eb61796cf7c010   \n",
       "224657  5dd4b80627889b0006b7772d  1.939098    F6  5dbc1d84c1eb61796cf7c010   \n",
       "224658  5dd4b80627889b0006b7772d  1.939098    F6  5dbc1d84c1eb61796cf7c010   \n",
       "224659  5dd4b80627889b0006b7772d  1.939098    F6  5dbc1d84c1eb61796cf7c010   \n",
       "224660  5dd4b80627889b0006b7772d  1.939098    F6  5dbc1d84c1eb61796cf7c010   \n",
       "224661  5dd4b80627889b0006b7772d  1.939098    F6  5dbc1d84c1eb61796cf7c010   \n",
       "224662  5dd4b80627889b0006b7772d  1.939098    F6  5dbc1d84c1eb61796cf7c010   \n",
       "224663  5dd4b80627889b0006b7772d  1.939098    F6  5dbc1d84c1eb61796cf7c010   \n",
       "\n",
       "        wifi_len  wifi_mean  ...   y_magne   z_magne   x_gyros   y_gyros  \\\n",
       "224653     0.262  -0.325291  ...  0.713641  0.364087  1.094598 -1.136477   \n",
       "224654     0.260  -0.363039  ...  0.713641  0.364087  1.094598 -1.136477   \n",
       "224655     0.260  -0.392559  ...  1.250962  1.443087  0.997709 -1.189775   \n",
       "224656     0.258  -0.459344  ...  1.017085 -1.334331  0.750805 -0.969883   \n",
       "224657     0.256  -0.468747  ... -0.246060  1.202760  1.125877 -1.203099   \n",
       "224658     0.234  -0.422163  ...  0.657207  0.425335  1.844603 -1.069854   \n",
       "224659     0.238  -0.380232  ... -0.634699  0.024486  1.094598 -1.536361   \n",
       "224660     0.246  -0.278861  ...  0.412877  0.254555  0.625828 -1.536361   \n",
       "224661     0.258  -0.282377  ... -0.508025  3.236964 -0.084674 -1.573146   \n",
       "224662     0.258  -0.258731  ...  0.569195  0.178858 -1.374161 -0.536651   \n",
       "224663     0.256  -0.281942  ... -0.559328 -4.136266  1.438321 -0.803216   \n",
       "\n",
       "         z_gyros  x_rotate  y_rotate  z_rotate  fold  site_id  \n",
       "224653  0.824800 -1.021430  0.694993  1.354620   0.0       22  \n",
       "224654  0.824800 -1.021430  0.694993  1.354620   6.0       22  \n",
       "224655  0.049868 -1.275551 -0.050002  1.369476   8.0       22  \n",
       "224656 -0.647448 -1.030683 -0.485144  1.381195   4.0       22  \n",
       "224657  0.514863 -0.930176  0.546333  1.365165   8.0       22  \n",
       "224658 -0.492567 -1.235080  0.291654  1.367361   2.0       22  \n",
       "224659 -0.260070 -1.412949 -0.177800  1.355965   7.0       22  \n",
       "224660  0.359982 -1.125339  0.728169  1.377426   4.0       22  \n",
       "224661  0.058029 -1.486332  0.197214  1.448679   1.0       22  \n",
       "224662 -0.570007  0.243419 -0.569259 -1.188966   5.0       22  \n",
       "224663  0.514863 -1.137083  0.331922  1.231725   1.0       22  \n",
       "\n",
       "[11 rows x 29 columns]"
      ]
     },
     "execution_count": 74,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# NOTE(review): ad-hoc inspection of `ss`, which is defined outside this view - explain or remove before sharing\n",
    "ss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 75,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Series([], Name: path, dtype: int64)"
      ]
     },
     "execution_count": 75,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Sanity check: paths whose rows lack accelerometer data (an empty result means no x_acce is missing).\n",
    "train_all.loc[train_all['x_acce'].isna(), 'path'].value_counts()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 76,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "begin fold: 0\n",
      "fold 0 1.5483097549014528\n",
      "151.3674636340884\n",
      "elasped time: 85.34241080284119\n"
     ]
    }
   ],
   "source": [
    "# Train the wifi + sensor model per CV fold, collect out-of-fold predictions and average test predictions.\n",
    "import time\n",
    "t1 = time.time()\n",
    "pred_cols = ['x','y']\n",
    "# Per-row dense features fed next to the wifi sequences (same column order for train and test).\n",
    "dense_cols = ['timestamp','floorNo','wifi_len','wifi_mean', 'x_acce',\n",
    "       'y_acce', 'z_acce', 'x_magne', 'y_magne', 'z_magne', 'x_gyros',\n",
    "       'y_gyros', 'z_gyros', 'x_rotate', 'y_rotate', 'z_rotate']\n",
    "train_inputs = preprocess_inputs(train_all,cols=['ssid', 'bssid', 'rssi'])\n",
    "train_inputs_time = train_all[dense_cols].values\n",
    "train_inputs_site = train_all['site_id'].values\n",
    "train_labels = train_all[pred_cols].values\n",
    "test_inputs = preprocess_inputs(test_all,cols=['ssid','bssid', 'rssi'])\n",
    "test_inputs_time = test_all[dense_cols].values\n",
    "test_inputs_site = test_all['site_id'].values\n",
    "\n",
    "x_test = test_inputs\n",
    "x_test_time = test_inputs_time\n",
    "x_test_site = test_inputs_site\n",
    "\n",
    "# Only the first N_FOLDS_TO_RUN folds are trained (quick experiment); set it to N_SPLITS for full CV.\n",
    "N_FOLDS_TO_RUN = 1\n",
    "MODEL_DIR = 'rnn_model_wifisensor'\n",
    "os.makedirs(MODEL_DIR, exist_ok=True)  # make sure the checkpoint directory exists\n",
    "fold_ids = train_all['fold'].values\n",
    "oof_xy = np.zeros(train_labels.shape)\n",
    "oof_mask = np.zeros(len(train_labels), dtype=bool)  # rows that actually received an out-of-fold prediction\n",
    "y_test_pred = 0\n",
    "n_trained = 0\n",
    "for fold_id in range(min(N_FOLDS_TO_RUN, N_SPLITS)):\n",
    "    # Positional indices: the numpy arrays above are sliced by position, not by train_all's index labels.\n",
    "    trn_idx = np.flatnonzero(fold_ids != fold_id)\n",
    "    val_idx = np.flatnonzero(fold_ids == fold_id)\n",
    "    print('begin fold:',fold_id)\n",
    "    x_train, x_val = train_inputs[trn_idx],train_inputs[val_idx]\n",
    "    x_train_time, x_val_time = train_inputs_time[trn_idx],train_inputs_time[val_idx]\n",
    "    x_train_site, x_val_site = train_inputs_site[trn_idx],train_inputs_site[val_idx]\n",
    "    y_train, y_val = train_labels[trn_idx],train_labels[val_idx]\n",
    "    \n",
    "    model = build_model_mix(len(ssiddict),len(bssiddict),len(sitedict),seqlen,lr=0.001)\n",
    "\n",
    "    # NOTE(review): ReduceLROnPlateau and EarlyStopping both use patience=5, so training may stop\n",
    "    # before a reduced learning rate is ever used - consider a larger EarlyStopping patience.\n",
    "    history = model.fit(\n",
    "            [x_train,x_train_time,x_train_site], y_train,\n",
    "            validation_data=([x_val,x_val_time,x_val_site], y_val),\n",
    "            batch_size=128,\n",
    "            epochs=100,\n",
    "            verbose=1,\n",
    "            callbacks=[\n",
    "                tf.keras.callbacks.ReduceLROnPlateau(patience=5),\n",
    "                # save_best_only keeps the saved file consistent with restore_best_weights below\n",
    "                tf.keras.callbacks.ModelCheckpoint(os.path.join(MODEL_DIR, 'model_fold_{}.h5'.format(fold_id)),\n",
    "                                                   monitor='val_loss', save_best_only=True),\n",
    "                tf.keras.callbacks.EarlyStopping(monitor='val_loss', min_delta=1e-4,\n",
    "                                  patience=5, mode='min', restore_best_weights=True)\n",
    "            ]\n",
    "        )\n",
    "\n",
    "    # Optionally reload the checkpointed weights for this fold:\n",
    "    # model.load_weights(os.path.join(MODEL_DIR, 'model_fold_{}.h5'.format(fold_id)))\n",
    "\n",
    "    y_val_pred = model.predict([x_val,x_val_time,x_val_site])\n",
    "    y_test_pred += model.predict([x_test,x_test_time,x_test_site])\n",
    "    oof_xy[val_idx] = y_val_pred\n",
    "    oof_mask[val_idx] = True\n",
    "    n_trained += 1\n",
    "    print('fold',fold_id, np.mean(np.sqrt(np.sum((y_val-y_val_pred)**2,axis=1))))\n",
    "# Average over the folds that were actually trained.\n",
    "y_test_pred = y_test_pred/n_trained\n",
    "train_labels_inv = pd.DataFrame(train_labels, columns=['x','y'])\n",
    "oof_xy_pred_inv = pd.DataFrame(oof_xy, columns=['x','y'])\n",
    "y_test_pred_inv = pd.DataFrame(y_test_pred, columns=['x','y'])\n",
    "# Score only rows predicted out-of-fold; the rest of oof_xy is still zero and would inflate the error.\n",
    "oof_err = np.sqrt(np.sum((train_labels[oof_mask] - oof_xy[oof_mask])**2, axis=1))\n",
    "print(np.mean(oof_err))\n",
    "\n",
    "t2 = time.time()\n",
    "print('elasped time:', t2 - t1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 77,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "fold 0 1.5483097549014528\n"
     ]
    }
   ],
   "source": [
    "fold_dist = np.sqrt(((y_val - y_val_pred) ** 2).sum(axis=1))  # per-row Euclidean error on the validation fold\n",
    "print('fold', fold_id, fold_dist.mean())\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 78,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Assign positionally: y_test_pred_inv has a fresh 0..n-1 index that need not match test_all's index.\n",
    "test_all['x'] = y_test_pred_inv['x'].to_numpy()\n",
    "test_all['y'] = y_test_pred_inv['y'].to_numpy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 79,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>timestamp</th>\n",
       "      <th>ssid</th>\n",
       "      <th>bssid</th>\n",
       "      <th>rssi</th>\n",
       "      <th>path</th>\n",
       "      <th>floorNo</th>\n",
       "      <th>wifi_len</th>\n",
       "      <th>wifi_mean</th>\n",
       "      <th>wifi_median</th>\n",
       "      <th>wifi_std</th>\n",
       "      <th>...</th>\n",
       "      <th>z_magne</th>\n",
       "      <th>x_gyros</th>\n",
       "      <th>y_gyros</th>\n",
       "      <th>z_gyros</th>\n",
       "      <th>x_rotate</th>\n",
       "      <th>y_rotate</th>\n",
       "      <th>z_rotate</th>\n",
       "      <th>site_id</th>\n",
       "      <th>x</th>\n",
       "      <th>y</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>0.345764</td>\n",
       "      <td>[7007, 9522, 15215, 18669, 15215, 19396, 4851,...</td>\n",
       "      <td>[35106, 10783, 39335, 4531, 48757, 19211, 1176...</td>\n",
       "      <td>[1.9251305288464635, 1.4331324770012857, 1.334...</td>\n",
       "      <td>00ff0c9a71cc37a2ebdd0f05</td>\n",
       "      <td>0.845957</td>\n",
       "      <td>0.038</td>\n",
       "      <td>0.024464</td>\n",
       "      <td>-0.338061</td>\n",
       "      <td>1.033093</td>\n",
       "      <td>...</td>\n",
       "      <td>0.303976</td>\n",
       "      <td>1.588786</td>\n",
       "      <td>0.558175</td>\n",
       "      <td>0.87003</td>\n",
       "      <td>0.073650</td>\n",
       "      <td>1.198900</td>\n",
       "      <td>0.804293</td>\n",
       "      <td>19</td>\n",
       "      <td>73.382919</td>\n",
       "      <td>88.722977</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>0.345765</td>\n",
       "      <td>[18669, 9522, 7007, 19396, 15215, 15215, 1264,...</td>\n",
       "      <td>[4531, 10783, 35106, 19211, 39335, 48757, 6030...</td>\n",
       "      <td>[2.1219297495845346, 1.4331324770012857, 1.334...</td>\n",
       "      <td>00ff0c9a71cc37a2ebdd0f05</td>\n",
       "      <td>0.845957</td>\n",
       "      <td>0.040</td>\n",
       "      <td>0.075218</td>\n",
       "      <td>-0.338061</td>\n",
       "      <td>0.991529</td>\n",
       "      <td>...</td>\n",
       "      <td>0.500978</td>\n",
       "      <td>1.588793</td>\n",
       "      <td>0.248723</td>\n",
       "      <td>0.80081</td>\n",
       "      <td>0.203768</td>\n",
       "      <td>1.516082</td>\n",
       "      <td>0.877682</td>\n",
       "      <td>19</td>\n",
       "      <td>73.456726</td>\n",
       "      <td>87.362114</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>2 rows × 27 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "   timestamp                                               ssid  \\\n",
       "0   0.345764  [7007, 9522, 15215, 18669, 15215, 19396, 4851,...   \n",
       "1   0.345765  [18669, 9522, 7007, 19396, 15215, 15215, 1264,...   \n",
       "\n",
       "                                               bssid  \\\n",
       "0  [35106, 10783, 39335, 4531, 48757, 19211, 1176...   \n",
       "1  [4531, 10783, 35106, 19211, 39335, 48757, 6030...   \n",
       "\n",
       "                                                rssi  \\\n",
       "0  [1.9251305288464635, 1.4331324770012857, 1.334...   \n",
       "1  [2.1219297495845346, 1.4331324770012857, 1.334...   \n",
       "\n",
       "                       path   floorNo  wifi_len  wifi_mean  wifi_median  \\\n",
       "0  00ff0c9a71cc37a2ebdd0f05  0.845957     0.038   0.024464    -0.338061   \n",
       "1  00ff0c9a71cc37a2ebdd0f05  0.845957     0.040   0.075218    -0.338061   \n",
       "\n",
       "   wifi_std  ...   z_magne   x_gyros   y_gyros  z_gyros  x_rotate  y_rotate  \\\n",
       "0  1.033093  ...  0.303976  1.588786  0.558175  0.87003  0.073650  1.198900   \n",
       "1  0.991529  ...  0.500978  1.588793  0.248723  0.80081  0.203768  1.516082   \n",
       "\n",
       "   z_rotate  site_id          x          y  \n",
       "0  0.804293       19  73.382919  88.722977  \n",
       "1  0.877682       19  73.456726  87.362114  \n",
       "\n",
       "[2 rows x 27 columns]"
      ]
     },
     "execution_count": 79,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Spot-check the first rows of test_all, including the attached x/y predictions\n",
    "test_all.head(2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 80,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>timestamp</th>\n",
       "      <th>path</th>\n",
       "      <th>site</th>\n",
       "      <th>x</th>\n",
       "      <th>y</th>\n",
       "      <th>t1_wifi</th>\n",
       "      <th>path_id</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>0.345764</td>\n",
       "      <td>00ff0c9a71cc37a2ebdd0f05</td>\n",
       "      <td>5da1389e4db8ce0c98bd0547</td>\n",
       "      <td>73.382919</td>\n",
       "      <td>88.722977</td>\n",
       "      <td>1180.0</td>\n",
       "      <td>5da1389e4db8ce0c98bd0547_00ff0c9a71cc37a2ebdd0f05</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>0.345765</td>\n",
       "      <td>00ff0c9a71cc37a2ebdd0f05</td>\n",
       "      <td>5da1389e4db8ce0c98bd0547</td>\n",
       "      <td>73.456726</td>\n",
       "      <td>87.362114</td>\n",
       "      <td>3048.0</td>\n",
       "      <td>5da1389e4db8ce0c98bd0547_00ff0c9a71cc37a2ebdd0f05</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>0.345766</td>\n",
       "      <td>00ff0c9a71cc37a2ebdd0f05</td>\n",
       "      <td>5da1389e4db8ce0c98bd0547</td>\n",
       "      <td>72.727478</td>\n",
       "      <td>85.721558</td>\n",
       "      <td>4924.0</td>\n",
       "      <td>5da1389e4db8ce0c98bd0547_00ff0c9a71cc37a2ebdd0f05</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>0.345766</td>\n",
       "      <td>00ff0c9a71cc37a2ebdd0f05</td>\n",
       "      <td>5da1389e4db8ce0c98bd0547</td>\n",
       "      <td>72.815376</td>\n",
       "      <td>83.148605</td>\n",
       "      <td>6816.0</td>\n",
       "      <td>5da1389e4db8ce0c98bd0547_00ff0c9a71cc37a2ebdd0f05</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>0.345767</td>\n",
       "      <td>00ff0c9a71cc37a2ebdd0f05</td>\n",
       "      <td>5da1389e4db8ce0c98bd0547</td>\n",
       "      <td>73.723251</td>\n",
       "      <td>87.773392</td>\n",
       "      <td>8693.0</td>\n",
       "      <td>5da1389e4db8ce0c98bd0547_00ff0c9a71cc37a2ebdd0f05</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   timestamp                      path                      site          x  \\\n",
       "0   0.345764  00ff0c9a71cc37a2ebdd0f05  5da1389e4db8ce0c98bd0547  73.382919   \n",
       "1   0.345765  00ff0c9a71cc37a2ebdd0f05  5da1389e4db8ce0c98bd0547  73.456726   \n",
       "2   0.345766  00ff0c9a71cc37a2ebdd0f05  5da1389e4db8ce0c98bd0547  72.727478   \n",
       "3   0.345766  00ff0c9a71cc37a2ebdd0f05  5da1389e4db8ce0c98bd0547  72.815376   \n",
       "4   0.345767  00ff0c9a71cc37a2ebdd0f05  5da1389e4db8ce0c98bd0547  73.723251   \n",
       "\n",
       "           y  t1_wifi                                            path_id  \n",
       "0  88.722977   1180.0  5da1389e4db8ce0c98bd0547_00ff0c9a71cc37a2ebdd0f05  \n",
       "1  87.362114   3048.0  5da1389e4db8ce0c98bd0547_00ff0c9a71cc37a2ebdd0f05  \n",
       "2  85.721558   4924.0  5da1389e4db8ce0c98bd0547_00ff0c9a71cc37a2ebdd0f05  \n",
       "3  83.148605   6816.0  5da1389e4db8ce0c98bd0547_00ff0c9a71cc37a2ebdd0f05  \n",
       "4  87.773392   8693.0  5da1389e4db8ce0c98bd0547_00ff0c9a71cc37a2ebdd0f05  "
      ]
     },
     "execution_count": 80,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Build the submission frame: undo the timestamp scaling and re-express it per path.\n",
    "result = test_all[['timestamp','path','site','x','y']].copy()  # explicit copy instead of a slice of test_all\n",
    "# inverse_transform expects 2-D input; assumes ss2 was fit on the single timestamp column - TODO confirm\n",
    "result['t1_wifi'] = ss2.inverse_transform(result['timestamp'].to_numpy().reshape(-1, 1)).ravel()\n",
    "\n",
    "# NOTE(review): testpath2gap[path][0] is presumably a per-path time offset from preprocessing - verify\n",
    "result['t1_wifi'] = [xx-testpath2gap[yy][0] for (xx,yy) in zip(result['t1_wifi'],result['path'])]\n",
    "result['path_id'] = result['site']+'_'+result['path']\n",
    "result.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 81,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>timestamp</th>\n",
       "      <th>path</th>\n",
       "      <th>site</th>\n",
       "      <th>x</th>\n",
       "      <th>y</th>\n",
       "      <th>t1_wifi</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>path_id</th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>5da1389e4db8ce0c98bd0547_00ff0c9a71cc37a2ebdd0f05</th>\n",
       "      <td>0.345764</td>\n",
       "      <td>00ff0c9a71cc37a2ebdd0f05</td>\n",
       "      <td>5da1389e4db8ce0c98bd0547</td>\n",
       "      <td>73.382919</td>\n",
       "      <td>88.722977</td>\n",
       "      <td>1180.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5da1389e4db8ce0c98bd0547_00ff0c9a71cc37a2ebdd0f05</th>\n",
       "      <td>0.345765</td>\n",
       "      <td>00ff0c9a71cc37a2ebdd0f05</td>\n",
       "      <td>5da1389e4db8ce0c98bd0547</td>\n",
       "      <td>73.456726</td>\n",
       "      <td>87.362114</td>\n",
       "      <td>3048.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5da1389e4db8ce0c98bd0547_00ff0c9a71cc37a2ebdd0f05</th>\n",
       "      <td>0.345766</td>\n",
       "      <td>00ff0c9a71cc37a2ebdd0f05</td>\n",
       "      <td>5da1389e4db8ce0c98bd0547</td>\n",
       "      <td>72.727478</td>\n",
       "      <td>85.721558</td>\n",
       "      <td>4924.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5da1389e4db8ce0c98bd0547_00ff0c9a71cc37a2ebdd0f05</th>\n",
       "      <td>0.345766</td>\n",
       "      <td>00ff0c9a71cc37a2ebdd0f05</td>\n",
       "      <td>5da1389e4db8ce0c98bd0547</td>\n",
       "      <td>72.815376</td>\n",
       "      <td>83.148605</td>\n",
       "      <td>6816.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5da1389e4db8ce0c98bd0547_00ff0c9a71cc37a2ebdd0f05</th>\n",
       "      <td>0.345767</td>\n",
       "      <td>00ff0c9a71cc37a2ebdd0f05</td>\n",
       "      <td>5da1389e4db8ce0c98bd0547</td>\n",
       "      <td>73.723251</td>\n",
       "      <td>87.773392</td>\n",
       "      <td>8693.0</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                                                   timestamp  \\\n",
       "path_id                                                        \n",
       "5da1389e4db8ce0c98bd0547_00ff0c9a71cc37a2ebdd0f05   0.345764   \n",
       "5da1389e4db8ce0c98bd0547_00ff0c9a71cc37a2ebdd0f05   0.345765   \n",
       "5da1389e4db8ce0c98bd0547_00ff0c9a71cc37a2ebdd0f05   0.345766   \n",
       "5da1389e4db8ce0c98bd0547_00ff0c9a71cc37a2ebdd0f05   0.345766   \n",
       "5da1389e4db8ce0c98bd0547_00ff0c9a71cc37a2ebdd0f05   0.345767   \n",
       "\n",
       "                                                                       path  \\\n",
       "path_id                                                                       \n",
       "5da1389e4db8ce0c98bd0547_00ff0c9a71cc37a2ebdd0f05  00ff0c9a71cc37a2ebdd0f05   \n",
       "5da1389e4db8ce0c98bd0547_00ff0c9a71cc37a2ebdd0f05  00ff0c9a71cc37a2ebdd0f05   \n",
       "5da1389e4db8ce0c98bd0547_00ff0c9a71cc37a2ebdd0f05  00ff0c9a71cc37a2ebdd0f05   \n",
       "5da1389e4db8ce0c98bd0547_00ff0c9a71cc37a2ebdd0f05  00ff0c9a71cc37a2ebdd0f05   \n",
       "5da1389e4db8ce0c98bd0547_00ff0c9a71cc37a2ebdd0f05  00ff0c9a71cc37a2ebdd0f05   \n",
       "\n",
       "                                                                       site  \\\n",
       "path_id                                                                       \n",
       "5da1389e4db8ce0c98bd0547_00ff0c9a71cc37a2ebdd0f05  5da1389e4db8ce0c98bd0547   \n",
       "5da1389e4db8ce0c98bd0547_00ff0c9a71cc37a2ebdd0f05  5da1389e4db8ce0c98bd0547   \n",
       "5da1389e4db8ce0c98bd0547_00ff0c9a71cc37a2ebdd0f05  5da1389e4db8ce0c98bd0547   \n",
       "5da1389e4db8ce0c98bd0547_00ff0c9a71cc37a2ebdd0f05  5da1389e4db8ce0c98bd0547   \n",
       "5da1389e4db8ce0c98bd0547_00ff0c9a71cc37a2ebdd0f05  5da1389e4db8ce0c98bd0547   \n",
       "\n",
       "                                                           x          y  \\\n",
       "path_id                                                                   \n",
       "5da1389e4db8ce0c98bd0547_00ff0c9a71cc37a2ebdd0f05  73.382919  88.722977   \n",
       "5da1389e4db8ce0c98bd0547_00ff0c9a71cc37a2ebdd0f05  73.456726  87.362114   \n",
       "5da1389e4db8ce0c98bd0547_00ff0c9a71cc37a2ebdd0f05  72.727478  85.721558   \n",
       "5da1389e4db8ce0c98bd0547_00ff0c9a71cc37a2ebdd0f05  72.815376  83.148605   \n",
       "5da1389e4db8ce0c98bd0547_00ff0c9a71cc37a2ebdd0f05  73.723251  87.773392   \n",
       "\n",
       "                                                   t1_wifi  \n",
       "path_id                                                     \n",
       "5da1389e4db8ce0c98bd0547_00ff0c9a71cc37a2ebdd0f05   1180.0  \n",
       "5da1389e4db8ce0c98bd0547_00ff0c9a71cc37a2ebdd0f05   3048.0  \n",
       "5da1389e4db8ce0c98bd0547_00ff0c9a71cc37a2ebdd0f05   4924.0  \n",
       "5da1389e4db8ce0c98bd0547_00ff0c9a71cc37a2ebdd0f05   6816.0  \n",
       "5da1389e4db8ce0c98bd0547_00ff0c9a71cc37a2ebdd0f05   8693.0  "
      ]
     },
     "execution_count": 81,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Guarded so re-running this cell does not raise once path_id has already become the index.\n",
    "if 'path_id' in result.columns:\n",
    "    result = result.set_index('path_id')\n",
    "result.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 82,
   "metadata": {},
   "outputs": [],
   "source": [
    "from scipy.spatial.transform import Rotation as R\n",
    "from PIL import Image\n",
    "from mpl_toolkits.mplot3d import Axes3D\n",
    "import plotly.graph_objs as go\n",
    "from pathlib import Path\n",
    "import scipy.signal as signal\n",
    "import json\n",
    "import seaborn as sns # visualization\n",
    "from dataclasses import dataclass\n",
    "\n",
    "import matplotlib.pyplot as plt  # visualization\n",
    "import numpy as np  # linear algebra\n",
    "import random\n",
    "import pandas as pd\n",
    "from collections import Counter, defaultdict\n",
    "\n",
    "plt.rcParams.update({'font.size': 14})\n",
    "\n",
    "def split_ts_seq(ts_seq, sep_ts):\n",
    "    \"\"\"Split the rows of ts_seq into consecutive chunks at the separator timestamps.\n",
    "\n",
    "    :param ts_seq: 2-D array whose column 0 holds timestamps; assumed sorted ascending\n",
    "        (np.searchsorted relies on it).\n",
    "    :param sep_ts: separator timestamps; duplicates are ignored and each separator closes\n",
    "        its chunk inclusively (side='right'). Separators that would yield an empty chunk are skipped.\n",
    "    :return: list of row-slice copies of ts_seq, plus a final chunk for rows after the last separator.\n",
    "    \"\"\"\n",
    "    timestamps = ts_seq[:, 0].astype(float)\n",
    "    chunks = []\n",
    "    lo = 0\n",
    "    for boundary in np.unique(sep_ts):\n",
    "        hi = np.searchsorted(timestamps, boundary, side='right')\n",
    "        if hi != lo:\n",
    "            chunks.append(ts_seq[lo:hi, :].copy())\n",
    "            lo = hi\n",
    "\n",
    "    # rows after the final separator form the tail chunk\n",
    "    if lo < ts_seq.shape[0]:\n",
    "        chunks.append(ts_seq[lo:, :].copy())\n",
    "\n",
    "    return chunks\n",
    "\n",
    "\n",
    "def correct_trajectory(original_xys, end_xy):\n",
    "    \"\"\"Rotate and scale a trajectory about its first point so its last point lands on end_xy.\n",
    "\n",
    "    With A = original_xys[0], B' = original_xys[-1] and B = end_xy, every later point is rotated\n",
    "    about A by the angle between AB' and AB, and its distance from A is scaled by |AB| / |AB'|.\n",
    "\n",
    "    :param original_xys: numpy ndarray, shape (N, 2)\n",
    "    :param end_xy: numpy ndarray, shape (2,) - the target end point B\n",
    "    :return: float ndarray, shape (N, 2); row 0 equals original_xys[0]. If B' == A the scale factor\n",
    "        is undefined and the rows after the first come out non-finite.\n",
    "    \"\"\"\n",
    "    A = original_xys[0, :]\n",
    "    B = end_xy\n",
    "    Bp = original_xys[-1, :]\n",
    "\n",
    "    angle_BAX = np.arctan2(B[1] - A[1], B[0] - A[0])\n",
    "    angle_BpAX = np.arctan2(Bp[1] - A[1], Bp[0] - A[0])\n",
    "    angle_BpAB = angle_BpAX - angle_BAX\n",
    "    AB = np.sqrt(np.sum((B - A) ** 2))\n",
    "    ABp = np.sqrt(np.sum((Bp - A) ** 2))\n",
    "\n",
    "    # Vectorised over the remaining points (avoids O(N^2) copying from repeated np.append).\n",
    "    rel = original_xys[1:, :] - A\n",
    "    angle_CAX = np.arctan2(rel[:, 1], rel[:, 0]) - angle_BpAB\n",
    "    ACp = np.sqrt(np.sum(rel ** 2, axis=1))\n",
    "    AC = ACp * AB / ABp\n",
    "    C = np.column_stack((AC * np.cos(angle_CAX), AC * np.sin(angle_CAX))) + A\n",
    "\n",
    "    return np.vstack((A, C)).astype(float)\n",
    "\n",
    "\n",
    "def correct_positions(rel_positions, reference_positions):\n",
    "    \"\"\"Turn per-step displacements into absolute positions pinned to reference waypoints.\n",
    "\n",
    "    :param rel_positions: ndarray whose column 0 holds timestamps and columns 1:3 the step\n",
    "        displacements; rows assumed ordered by time.\n",
    "    :param reference_positions: ndarray of waypoints, column 0 timestamps and columns 1:3 x, y;\n",
    "        rows assumed ordered by time.\n",
    "    :return: float ndarray with 3 columns (timestamp, x, y). Each segment between consecutive\n",
    "        waypoints is accumulated from its start waypoint, then rotated/scaled by correct_trajectory\n",
    "        so it lands on the next waypoint; waypoints shared by two segments appear once.\n",
    "    :raises AssertionError: if the steps cannot be split into one chunk per waypoint gap.\n",
    "    \"\"\"\n",
    "    rel_positions_list = split_ts_seq(rel_positions, reference_positions[:, 0])\n",
    "    if len(rel_positions_list) != reference_positions.shape[0] - 1:\n",
    "        # presumably the extra chunk holds steps after the last waypoint - drop it\n",
    "        del rel_positions_list[-1]\n",
    "    assert len(rel_positions_list) == reference_positions.shape[0] - 1\n",
    "\n",
    "    segments = []\n",
    "    for i, rel_ps in enumerate(rel_positions_list):\n",
    "        start_position = reference_positions[i]\n",
    "        end_position = reference_positions[i + 1]\n",
    "        abs_ps = np.zeros(rel_ps.shape)\n",
    "        abs_ps[:, 0] = rel_ps[:, 0]\n",
    "        # Running sum of the step displacements starting from the start waypoint\n",
    "        # (same left-to-right addition order as a step-by-step loop).\n",
    "        abs_ps[:, 1:3] = np.cumsum(np.vstack((start_position[1:3], rel_ps[:, 1:3])), axis=0)[1:]\n",
    "        abs_ps = np.insert(abs_ps, 0, start_position, axis=0)\n",
    "        corrected_xys = correct_trajectory(abs_ps[:, 1:3], end_position[1:3])\n",
    "        corrected_ps = np.column_stack((abs_ps[:, 0], corrected_xys))\n",
    "        # Consecutive segments share a waypoint; keep it only from the first segment.\n",
    "        segments.append(corrected_ps if i == 0 else corrected_ps[1:])\n",
    "\n",
    "    # Concatenate once instead of growing an array inside the loop.\n",
    "    if not segments:\n",
    "        return np.zeros((0, 3))\n",
    "    return np.concatenate(segments, axis=0)\n",
    "\n",
    "\n",
    "def init_parameters_filter(sample_freq, warmup_data, cut_off_freq=2):\n",
    "    order = 4\n",
    "    filter_b, filter_a = signal.butter(order, cut_off_freq / (sample_freq / 2), 'low', False)\n",
    "    zf = signal.lfilter_zi(filter_b, filter_a)\n",
    "    _, zf = signal.lfilter(filter_b, filter_a, warmup_data, zi=zf)\n",
    "    _, filter_zf = signal.lfilter(filter_b, filter_a, warmup_data, zi=zf)\n",
    "\n",
    "    return filter_b, filter_a, filter_zf\n",
    "\n",
    "\n",
    "def get_rotation_matrix_from_vector(rotation_vector):\n",
    "    q1 = rotation_vector[0]\n",
    "    q2 = rotation_vector[1]\n",
    "    q3 = rotation_vector[2]\n",
    "\n",
    "    if rotation_vector.size >= 4:\n",
    "        q0 = rotation_vector[3]\n",
    "    else:\n",
    "        q0 = 1 - q1*q1 - q2*q2 - q3*q3\n",
    "        if q0 > 0:\n",
    "            q0 = np.sqrt(q0)\n",
    "        else:\n",
    "            q0 = 0\n",
    "\n",
    "    sq_q1 = 2 * q1 * q1\n",
    "    sq_q2 = 2 * q2 * q2\n",
    "    sq_q3 = 2 * q3 * q3\n",
    "    q1_q2 = 2 * q1 * q2\n",
    "    q3_q0 = 2 * q3 * q0\n",
    "    q1_q3 = 2 * q1 * q3\n",
    "    q2_q0 = 2 * q2 * q0\n",
    "    q2_q3 = 2 * q2 * q3\n",
    "    q1_q0 = 2 * q1 * q0\n",
    "\n",
    "    R = np.zeros((9,))\n",
    "    if R.size == 9:\n",
    "        R[0] = 1 - sq_q2 - sq_q3\n",
    "        R[1] = q1_q2 - q3_q0\n",
    "        R[2] = q1_q3 + q2_q0\n",
    "\n",
    "        R[3] = q1_q2 + q3_q0\n",
    "        R[4] = 1 - sq_q1 - sq_q3\n",
    "        R[5] = q2_q3 - q1_q0\n",
    "\n",
    "        R[6] = q1_q3 - q2_q0\n",
    "        R[7] = q2_q3 + q1_q0\n",
    "        R[8] = 1 - sq_q1 - sq_q2\n",
    "\n",
    "        R = np.reshape(R, (3, 3))\n",
    "    elif R.size == 16:\n",
    "        R[0] = 1 - sq_q2 - sq_q3\n",
    "        R[1] = q1_q2 - q3_q0\n",
    "        R[2] = q1_q3 + q2_q0\n",
    "        R[3] = 0.0\n",
    "\n",
    "        R[4] = q1_q2 + q3_q0\n",
    "        R[5] = 1 - sq_q1 - sq_q3\n",
    "        R[6] = q2_q3 - q1_q0\n",
    "        R[7] = 0.0\n",
    "\n",
    "        R[8] = q1_q3 - q2_q0\n",
    "        R[9] = q2_q3 + q1_q0\n",
    "        R[10] = 1 - sq_q1 - sq_q2\n",
    "        R[11] = 0.0\n",
    "\n",
    "        R[12] = R[13] = R[14] = 0.0\n",
    "        R[15] = 1.0\n",
    "\n",
    "        R = np.reshape(R, (4, 4))\n",
    "\n",
    "    return R\n",
    "\n",
    "\n",
    "def get_orientation(R):\n",
    "    flat_R = R.flatten()\n",
    "    values = np.zeros((3,))\n",
    "    if np.size(flat_R) == 9:\n",
    "        values[0] = np.arctan2(flat_R[1], flat_R[4])\n",
    "        values[1] = np.arcsin(-flat_R[7])\n",
    "        values[2] = np.arctan2(-flat_R[6], flat_R[8])\n",
    "    else:\n",
    "        values[0] = np.arctan2(flat_R[1], flat_R[5])\n",
    "        values[1] = np.arcsin(-flat_R[9])\n",
    "        values[2] = np.arctan2(-flat_R[8], flat_R[10])\n",
    "\n",
    "    return values\n",
    "\n",
    "\n",
    "def compute_steps(acce_datas):\n",
    "    step_timestamps = np.array([])\n",
    "    step_indexs = np.array([], dtype=int)\n",
    "    step_acce_max_mins = np.zeros((0, 4))\n",
    "    sample_freq = 50\n",
    "    window_size = 22\n",
    "    low_acce_mag = 0.6\n",
    "    step_criterion = 1\n",
    "    interval_threshold = 250\n",
    "\n",
    "    acce_max = np.zeros((2,))\n",
    "    acce_min = np.zeros((2,))\n",
    "    acce_binarys = np.zeros((window_size,), dtype=int)\n",
    "    acce_mag_pre = 0\n",
    "    state_flag = 0\n",
    "\n",
    "    warmup_data = np.ones((window_size,)) * 9.81\n",
    "    filter_b, filter_a, filter_zf = init_parameters_filter(sample_freq, warmup_data)\n",
    "    acce_mag_window = np.zeros((window_size, 1))\n",
    "\n",
    "    # detect steps according to acceleration magnitudes\n",
    "    for i in np.arange(0, np.size(acce_datas, 0)):\n",
    "        acce_data = acce_datas[i, :]\n",
    "        acce_mag = np.sqrt(np.sum(acce_data[1:] ** 2))\n",
    "\n",
    "        acce_mag_filt, filter_zf = signal.lfilter(filter_b, filter_a, [acce_mag], zi=filter_zf)\n",
    "        acce_mag_filt = acce_mag_filt[0]\n",
    "\n",
    "        acce_mag_window = np.append(acce_mag_window, [acce_mag_filt])\n",
    "        acce_mag_window = np.delete(acce_mag_window, 0)\n",
    "        mean_gravity = np.mean(acce_mag_window)\n",
    "        acce_std = np.std(acce_mag_window)\n",
    "        mag_threshold = np.max([low_acce_mag, 0.4 * acce_std])\n",
    "\n",
    "        # detect valid peak or valley of acceleration magnitudes\n",
    "        acce_mag_filt_detrend = acce_mag_filt - mean_gravity\n",
    "        if acce_mag_filt_detrend > np.max([acce_mag_pre, mag_threshold]):\n",
    "            # peak\n",
    "            acce_binarys = np.append(acce_binarys, [1])\n",
    "            acce_binarys = np.delete(acce_binarys, 0)\n",
    "        elif acce_mag_filt_detrend < np.min([acce_mag_pre, -mag_threshold]):\n",
    "            # valley\n",
    "            acce_binarys = np.append(acce_binarys, [-1])\n",
    "            acce_binarys = np.delete(acce_binarys, 0)\n",
    "        else:\n",
    "            # between peak and valley\n",
    "            acce_binarys = np.append(acce_binarys, [0])\n",
    "            acce_binarys = np.delete(acce_binarys, 0)\n",
    "\n",
    "        if (acce_binarys[-1] == 0) and (acce_binarys[-2] == 1):\n",
    "            if state_flag == 0:\n",
    "                acce_max[:] = acce_data[0], acce_mag_filt\n",
    "                state_flag = 1\n",
    "            elif (state_flag == 1) and ((acce_data[0] - acce_max[0]) <= interval_threshold) and (\n",
    "                    acce_mag_filt > acce_max[1]):\n",
    "                acce_max[:] = acce_data[0], acce_mag_filt\n",
    "            elif (state_flag == 2) and ((acce_data[0] - acce_max[0]) > interval_threshold):\n",
    "                acce_max[:] = acce_data[0], acce_mag_filt\n",
    "                state_flag = 1\n",
    "\n",
    "        # choose reasonable step criterion and check if there is a valid step\n",
    "        # save step acceleration data: step_acce_max_mins = [timestamp, max, min, variance]\n",
    "        step_flag = False\n",
    "        if step_criterion == 2:\n",
    "            if (acce_binarys[-1] == -1) and ((acce_binarys[-2] == 1) or (acce_binarys[-2] == 0)):\n",
    "                step_flag = True\n",
    "        elif step_criterion == 3:\n",
    "            if (acce_binarys[-1] == -1) and (acce_binarys[-2] == 0) and (np.sum(acce_binarys[:-2]) > 1):\n",
    "                step_flag = True\n",
    "        else:\n",
    "            if (acce_binarys[-1] == 0) and acce_binarys[-2] == -1:\n",
    "                if (state_flag == 1) and ((acce_data[0] - acce_min[0]) > interval_threshold):\n",
    "                    acce_min[:] = acce_data[0], acce_mag_filt\n",
    "                    state_flag = 2\n",
    "                    step_flag = True\n",
    "                elif (state_flag == 2) and ((acce_data[0] - acce_min[0]) <= interval_threshold) and (\n",
    "                        acce_mag_filt < acce_min[1]):\n",
    "                    acce_min[:] = acce_data[0], acce_mag_filt\n",
    "        if step_flag:\n",
    "            step_timestamps = np.append(step_timestamps, acce_data[0])\n",
    "            step_indexs = np.append(step_indexs, [i])\n",
    "            step_acce_max_mins = np.append(step_acce_max_mins,\n",
    "                                           [[acce_data[0], acce_max[1], acce_min[1], acce_std ** 2]], axis=0)\n",
    "        acce_mag_pre = acce_mag_filt_detrend\n",
    "\n",
    "    return step_timestamps, step_indexs, step_acce_max_mins\n",
    "\n",
    "\n",
    "def compute_stride_length(step_acce_max_mins):\n",
    "    K = 0.4\n",
    "    K_max = 0.8\n",
    "    K_min = 0.4\n",
    "    para_a0 = 0.21468084\n",
    "    para_a1 = 0.09154517\n",
    "    para_a2 = 0.02301998\n",
    "\n",
    "    stride_lengths = np.zeros((step_acce_max_mins.shape[0], 2))\n",
    "    k_real = np.zeros((step_acce_max_mins.shape[0], 2))\n",
    "    step_timeperiod = np.zeros((step_acce_max_mins.shape[0] - 1, ))\n",
    "    stride_lengths[:, 0] = step_acce_max_mins[:, 0]\n",
    "    window_size = 2\n",
    "    step_timeperiod_temp = np.zeros((0, ))\n",
    "\n",
    "    # calculate every step period - step_timeperiod unit: second\n",
    "    for i in range(0, step_timeperiod.shape[0]):\n",
    "        step_timeperiod_data = (step_acce_max_mins[i + 1, 0] - step_acce_max_mins[i, 0]) / 1000\n",
    "        step_timeperiod_temp = np.append(step_timeperiod_temp, [step_timeperiod_data])\n",
    "        if step_timeperiod_temp.shape[0] > window_size:\n",
    "            step_timeperiod_temp = np.delete(step_timeperiod_temp, [0])\n",
    "        step_timeperiod[i] = np.sum(step_timeperiod_temp) / step_timeperiod_temp.shape[0]\n",
    "\n",
    "    # calculate parameters by step period and acceleration magnitude variance\n",
    "    k_real[:, 0] = step_acce_max_mins[:, 0]\n",
    "    k_real[0, 1] = K\n",
    "    for i in range(0, step_timeperiod.shape[0]):\n",
    "        k_real[i + 1, 1] = np.max([(para_a0 + para_a1 / step_timeperiod[i] + para_a2 * step_acce_max_mins[i, 3]), K_min])\n",
    "        k_real[i + 1, 1] = np.min([k_real[i + 1, 1], K_max]) * (K / K_min)\n",
    "\n",
    "    # calculate every stride length by parameters and max and min data of acceleration magnitude\n",
    "    stride_lengths[:, 1] = np.max([(step_acce_max_mins[:, 1] - step_acce_max_mins[:, 2]),\n",
    "                                   np.ones((step_acce_max_mins.shape[0], ))], axis=0)**(1 / 4) * k_real[:, 1]\n",
    "\n",
    "    return stride_lengths\n",
    "\n",
    "\n",
    "def compute_headings(ahrs_datas):\n",
    "    headings = np.zeros((np.size(ahrs_datas, 0), 2))\n",
    "    for i in np.arange(0, np.size(ahrs_datas, 0)):\n",
    "        ahrs_data = ahrs_datas[i, :]\n",
    "        rot_mat = get_rotation_matrix_from_vector(ahrs_data[1:])\n",
    "        azimuth, pitch, roll = get_orientation(rot_mat)\n",
    "        around_z = (-azimuth) % (2 * np.pi)\n",
    "        headings[i, :] = ahrs_data[0], around_z\n",
    "    return headings\n",
    "\n",
    "\n",
    "def compute_step_heading(step_timestamps, headings):\n",
    "    step_headings = np.zeros((len(step_timestamps), 2))\n",
    "    step_timestamps_index = 0\n",
    "    for i in range(0, len(headings)):\n",
    "        if step_timestamps_index < len(step_timestamps):\n",
    "            if headings[i, 0] == step_timestamps[step_timestamps_index]:\n",
    "                step_headings[step_timestamps_index, :] = headings[i, :]\n",
    "                step_timestamps_index += 1\n",
    "        else:\n",
    "            break\n",
    "    assert step_timestamps_index == len(step_timestamps)\n",
    "\n",
    "    return step_headings\n",
    "\n",
    "\n",
    "def compute_rel_positions(stride_lengths, step_headings):\n",
    "    rel_positions = np.zeros((stride_lengths.shape[0], 3))\n",
    "    for i in range(0, stride_lengths.shape[0]):\n",
    "        rel_positions[i, 0] = stride_lengths[i, 0]\n",
    "        rel_positions[i, 1] = -stride_lengths[i, 1] * np.sin(step_headings[i, 1])\n",
    "        rel_positions[i, 2] = stride_lengths[i, 1] * np.cos(step_headings[i, 1])\n",
    "\n",
    "    return rel_positions\n",
    "\n",
    "\n",
    "def compute_step_positions(acce_datas, ahrs_datas, posi_datas):\n",
    "    step_timestamps, step_indexs, step_acce_max_mins = compute_steps(acce_datas)\n",
    "    headings = compute_headings(ahrs_datas)\n",
    "    stride_lengths = compute_stride_length(step_acce_max_mins)\n",
    "    step_headings = compute_step_heading(step_timestamps, headings)\n",
    "    rel_positions = compute_rel_positions(stride_lengths, step_headings)\n",
    "    step_positions = correct_positions(rel_positions, posi_datas)\n",
    "\n",
    "    return step_positions\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 83,
   "metadata": {},
   "outputs": [],
   "source": [
    "sample_submission = pd.read_csv('../input/indoor-location-navigation/sample_submission.csv')\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 84,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th>timestamp</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>building</th>\n",
       "      <th>path_id</th>\n",
       "      <th></th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th rowspan=\"5\" valign=\"top\">5a0546857ecc773753327266</th>\n",
       "      <th>046cfa46be49fc10834815c6</th>\n",
       "      <td>[0000000000009, 0000000009017, 0000000015326, ...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>05d052dde78384b0c543d89c</th>\n",
       "      <td>[0000000000012, 0000000005748, 0000000014654, ...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>0c06cc9f21d172618d74c6c8</th>\n",
       "      <td>[0000000000011, 0000000011818, 0000000019825, ...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>146035943a1482883ed98570</th>\n",
       "      <td>[0000000000011, 0000000004535, 0000000011498, ...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1ef2771dfea25d508142ba06</th>\n",
       "      <td>[0000000000009, 0000000012833, 0000000021759, ...</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                                                                                           timestamp\n",
       "building                 path_id                                                                    \n",
       "5a0546857ecc773753327266 046cfa46be49fc10834815c6  [0000000000009, 0000000009017, 0000000015326, ...\n",
       "                         05d052dde78384b0c543d89c  [0000000000012, 0000000005748, 0000000014654, ...\n",
       "                         0c06cc9f21d172618d74c6c8  [0000000000011, 0000000011818, 0000000019825, ...\n",
       "                         146035943a1482883ed98570  [0000000000011, 0000000004535, 0000000011498, ...\n",
       "                         1ef2771dfea25d508142ba06  [0000000000009, 0000000012833, 0000000021759, ..."
      ]
     },
     "execution_count": 84,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "sample_submission['building'] = [x.split('_')[0] for x in sample_submission['site_path_timestamp']]\n",
    "sample_submission['path_id'] = [x.split('_')[1] for x in sample_submission['site_path_timestamp']]\n",
    "sample_submission['timestamp'] = [x.split('_')[2] for x in sample_submission['site_path_timestamp']]\n",
    "samples = pd.DataFrame(sample_submission.groupby(['building','path_id'])['timestamp'].apply(lambda x: list(x)))\n",
    "buildings = np.unique([x[0] for x in samples.index])\n",
    "samples.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 85,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "5a0546857ecc773753327266\n",
      "5c3c44b80379370013e0fd2b\n",
      "5d27075f03f801723c2e360f\n",
      "5d27096c03f801723c31e5e0\n",
      "5d27097f03f801723c320d97\n",
      "5d27099f03f801723c32511d\n",
      "5d2709a003f801723c3251bf\n",
      "5d2709b303f801723c327472\n",
      "5d2709bb03f801723c32852c\n",
      "5d2709c303f801723c3299ee\n",
      "5d2709d403f801723c32bd39\n",
      "5d2709e003f801723c32d896\n",
      "5da138274db8ce0c98bbd3d2\n",
      "5da1382d4db8ce0c98bbe92e\n",
      "5da138314db8ce0c98bbf3a0\n",
      "5da138364db8ce0c98bc00f1\n",
      "5da1383b4db8ce0c98bc11ab\n",
      "5da138754db8ce0c98bca82f\n",
      "5da138764db8ce0c98bcaa46\n",
      "5da1389e4db8ce0c98bd0547\n",
      "5da138b74db8ce0c98bd4774\n",
      "5da958dd46f8266d0737457b\n",
      "5dbc1d84c1eb61796cf7c010\n",
      "5dc8cea7659e181adb076a3f\n"
     ]
    }
   ],
   "source": [
    "from scipy.interpolate import interp1d\n",
    "from scipy.ndimage.filters import uniform_filter1d\n",
    "\n",
    "colacce = ['xyz_time','x_acce','y_acce','z_acce']\n",
    "colahrs = ['xyz_time','x_ahrs','y_ahrs','z_ahrs']\n",
    "\n",
    "for building in buildings:\n",
    "    print(building)\n",
    "    paths = samples.loc[building].index\n",
    "    # Acceleration info:\n",
    "    tfm = pd.read_csv(f'indoor_testing_accel/{building}.txt',index_col=0)\n",
    "    for path_id in paths:\n",
    "        # Original predicted values:\n",
    "        xy = result.loc[building+'_'+path_id]\n",
    "        tfmi = tfm.loc[path_id]\n",
    "        acce_datas = np.array(tfmi[colacce],dtype=np.float)\n",
    "        ahrs_datas = np.array(tfmi[colahrs],dtype=np.float)\n",
    "        posi_datas = np.array(xy[['t1_wifi','x','y']],dtype=np.float)\n",
    "        # Outlier removal:\n",
    "        xyout = uniform_filter1d(posi_datas,size=3,axis=0,mode='reflect')\n",
    "        xydiff = np.abs(posi_datas-xyout)\n",
    "        xystd = np.std(xydiff,axis=0)*3\n",
    "        posi_datas = posi_datas[(xydiff[:,1]<xystd[1])&(xydiff[:,2]<xystd[2])]\n",
    "        # Step detection:\n",
    "        step_timestamps, step_indexs, step_acce_max_mins = compute_steps(acce_datas)\n",
    "        stride_lengths = compute_stride_length(step_acce_max_mins)\n",
    "        # Orientation detection:\n",
    "        headings = compute_headings(ahrs_datas)\n",
    "        step_headings = compute_step_heading(step_timestamps, headings)\n",
    "        rel_positions = compute_rel_positions(stride_lengths, step_headings)\n",
    "        # Running average:\n",
    "        posi_datas = uniform_filter1d(posi_datas,size=3,axis=0,mode='reflect')[0::3,:]\n",
    "        # The 1st prediction timepoint should be earlier than the 1st step timepoint.\n",
    "        rel_positions = rel_positions[rel_positions[:,0]>posi_datas[0,0],:]\n",
    "        # If two consecutive predictions are in-between two step datapoints,\n",
    "        # the last one is removed, causing error (in the \"split_ts_seq\" function).\n",
    "        posi_index = [np.searchsorted(rel_positions[:,0], x, side='right') for x in posi_datas[:,0]]\n",
    "        u, i1, i2 = np.unique(posi_index, return_index=True, return_inverse=True)\n",
    "        posi_datas = np.vstack([np.mean(posi_datas[i2==i],axis=0) for i in np.unique(i2)])\n",
    "        # Position correction:\n",
    "        step_positions = correct_positions(rel_positions, posi_datas)\n",
    "        # Interpolate for timestamps in the testing set:\n",
    "\n",
    "        t = step_positions[:,0]\n",
    "        x = step_positions[:,1]\n",
    "        y = step_positions[:,2]\n",
    "        fx = interp1d(t, x, kind='linear', fill_value=(x[0],x[-1]), bounds_error=False) #fill_value=\"extrapolate\"\n",
    "        fy = interp1d(t, y, kind='linear', fill_value=(y[0],y[-1]), bounds_error=False)\n",
    "        # Output result:\n",
    "        t0 = np.array(samples.loc[(building,path_id),'timestamp'],dtype=np.float64)\n",
    "        sample_submission.loc[(sample_submission.building==building)&(sample_submission.path_id==path_id),'x'] = fx(t0)\n",
    "        sample_submission.loc[(sample_submission.building==building)&(sample_submission.path_id==path_id),'y'] = fy(t0)\n",
    "            \n",
    "        #sample_submission.loc[(sample_submission.building==building)&(sample_submission.path_id==path_id),'floor'] = floors.loc[building+'_'+path_id,'floor']\n",
    "#         break\n",
    "#     break\n",
    "\n",
    "# sample_submission[['site_path_timestamp','floor','x','y']].to_csv('submission_mix_v3.3_del_outlier.csv',index=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 86,
   "metadata": {},
   "outputs": [],
   "source": [
    "subold = pd.read_csv('submission_floor.csv')\n",
    "\n",
    "sample_submission['floor']=subold['floor']\n",
    "sample_submission[['site_path_timestamp','floor','x','y']].to_csv('submission_wifi_sensor.csv',index=False)\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
